From 230c4b38c10148c3fbbbb67cd1046766d73c865a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 13:14:02 -0700 Subject: [PATCH 001/124] [CI/Test] fix swap test for multi gpu (#4689) --- tests/kernels/test_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 8a27d51bb78d5..4cae15c79c489 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -222,11 +222,12 @@ def test_reshape_and_cache_flash( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, @@ -245,6 +246,7 @@ def test_reshape_and_cache_flash( head_size, kv_cache_dtype, dtype, + device=device, ) key_cache, value_cache = key_caches[0], value_caches[0] From 89579a201f2c84b512f0e1006ac2ea0d979803ab Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 8 May 2024 13:15:34 -0700 Subject: [PATCH 002/124] [Misc] Use vllm-flash-attn instead of flash-attn (#4686) --- Dockerfile | 21 --------------------- requirements-cuda.txt | 1 + setup.py | 14 +++++++++----- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/selector.py | 7 ++++--- 6 files changed, 16 insertions(+), 31 deletions(-) diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..ddca95c0e8786 100644 --- a/Dockerfile +++ b/Dockerfile @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6548d7a6684b2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 diff --git a/setup.py b/setup.py index 3768daf9d6fab..d9ba96b82329a 100644 --- a/setup.py +++ b/setup.py @@ -355,14 +355,18 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") - cuda_major = torch.version.cuda.split(".")[0] + cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: if "vllm-nccl-cu12" in req: - modified_requirements.append( - req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) - else: - modified_requirements.append(req) + req = req.replace("vllm-nccl-cu12", + f"vllm-nccl-cu{cuda_major}") + elif ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): + # vllm-flash-attn is built only for CUDA 12.1. + # Skip for other versions. + continue + modified_requirements.append(req) requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("requirements-rocm.txt") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c2fec9153f2d8..4bad226512b69 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8f13f3525512b..36e162671f944 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,8 +3,8 @@ import flashinfer import torch -from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 34da0f6c6cdfc..f4446bac6b8d2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS try: - import flash_attn # noqa: F401 + import vllm_flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention-2 backend because the flash_attn " - "package is not found. Please install it for better performance.") + "Cannot use FlashAttention-2 backend because the vllm_flash_attn " + "package is not found. `pip install vllm-flash-attn` for better " + "performance.") return _Backend.XFORMERS backend_by_env_var = envs.VLLM_ATTENTION_BACKEND From f942efb5a3712498b8b583d2d9345f98d15f22f0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 8 May 2024 14:44:00 -0700 Subject: [PATCH 003/124] [Dynamic Spec Decoding] Auto-disable by the running queue size (#4592) Co-authored-by: Cade Daniel --- tests/samplers/test_rejection_sampler.py | 13 +++- .../e2e/test_multistep_correctness.py | 34 ++++++++ .../spec_decode/e2e/test_ngram_correctness.py | 2 +- tests/spec_decode/test_dynamic_spec_decode.py | 77 +++++++++++++++++++ vllm/config.py | 29 +++++-- vllm/engine/arg_utils.py | 10 +++ vllm/executor/gpu_executor.py | 2 + .../layers/rejection_sampler.py | 11 ++- vllm/sequence.py | 6 ++ vllm/spec_decode/spec_decode_worker.py | 59 +++++++++----- vllm/spec_decode/top1_proposer.py | 23 ++++-- 11 files changed, 227 insertions(+), 39 deletions(-) create mode 100644 tests/spec_decode/test_dynamic_spec_decode.py diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 13b5b80cccfdc..00a2379502e6d 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -42,9 +42,11 @@ def mock_causal_accepted_tensor( @pytest.mark.parametrize( "which_tokens_accepted", ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int, +def test_correct_output_format(which_tokens_accepted: str, + disable_bonus_tokens: bool, seed: int, device: str): """Verify the output has correct format given predetermined accepted matrix. """ @@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, size=(batch_size, 1), dtype=torch.int64) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens) rejection_sampler.init_gpu_tensors(rank=0) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) - # Bonus tokens are currently disabled. Verify they're set to -1. + expected_bonus_token_ids = bonus_token_ids.clone() + # If bonus tokens disabled. Verify they are set to -1. # See https://github.com/vllm-project/vllm/issues/4212 - expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 + if disable_bonus_tokens: + expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1 if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index f15fcc4746d20..94d71fb012727 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + }, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_disable_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when all sequences disable speculation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 44ef400c91d34..c2004ff061a1e 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -57,7 +57,7 @@ @pytest.mark.parametrize("output_len", [ 256, ]) -@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) def test_ngram_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py new file mode 100644 index 0000000000000..948a74b22f0ae --- /dev/null +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_batch, mock_worker + + +@pytest.mark.parametrize('queue_size', [2, 4]) +@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) +@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@torch.inference_mode() +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): + """Verify that speculative tokens are disabled when the batch size + exceeds the threshold. + """ + disable_by_batch_size = 3 + + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=rejection_sampler, + metrics_collector=metrics_collector, + disable_by_batch_size=disable_by_batch_size) + + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + running_queue_size=queue_size) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + + # When the batch size is larger than the threshold, + # we expect no speculative tokens (0). + expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 + assert seq_group_metadata_list[ + 0].num_speculative_tokens == expected_num_spec_tokens + + draft_worker.sampler_output.side_effect = ValueError(exception_secret) + + proposer = Top1Proposer( + worker=draft_worker, + device='cpu', # not used + vocab_size=100, # not used + # Must be long enough to avoid being skipped due to length. + max_proposal_len=1024, + ) + + if queue_size < disable_by_batch_size: + # Should raise exception when executing the mocked draft model. + with pytest.raises(ValueError, match=exception_secret): + proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + else: + # Should not execute the draft model because spec decode is disabled + # for all requests. Accordingly, the proposal length should be 0. + proposals = proposer.get_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/vllm/config.py b/vllm/config.py index 5c3a8615eefb4..a2cb9b32c65fc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -692,6 +692,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: @@ -720,6 +721,9 @@ def maybe_create_spec_config( use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + speculative_disable_by_batch_size (Optional[int]): Disable + speculative decoding for new incoming requests when the number + of enqueue requests is larger than this value, if provided. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token @@ -730,7 +734,7 @@ def maybe_create_spec_config( the necessary conditions are met, else None. """ - if (speculative_model is None and num_speculative_tokens is None): + if speculative_model is None and num_speculative_tokens is None: return None if speculative_model is not None and num_speculative_tokens is None: @@ -739,6 +743,12 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + if (speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}") + assert (speculative_model is not None and num_speculative_tokens is not None) @@ -807,6 +817,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, ) @@ -876,8 +887,9 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, - ngram_prompt_lookup_max: int, - ngram_prompt_lookup_min: int, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ): """Create a SpeculativeConfig object. @@ -886,12 +898,19 @@ def __init__( draft_parallel_config: ParallelConfig for the draft model. num_speculative_tokens: The number of tokens to sample from the draft model before scoring with the target model. + speculative_disable_by_batch_size: Disable speculative + decoding for new incoming requests when the number of + enqueue requests is larger than this value. + ngram_prompt_lookup_max: Max size of ngram token window. + ngram_prompt_lookup_min: Min size of ngram token window. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8245eb307f7..c99b1806c7d1d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -83,6 +83,7 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None @@ -467,6 +468,13 @@ def add_cli_args( 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + parser.add_argument( '--ngram-prompt-lookup-max', type=int, @@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index e8559b6a5c0fe..fa3480fa64837 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -93,6 +93,8 @@ def _init_spec_worker(self): spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=self.speculative_config. + speculative_disable_by_batch_size, ) assert self.parallel_config.world_size == 1, ( diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 5edbbf2c70a49..b5f1e55d0e839 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): https://arxiv.org/pdf/2302.01318.pdf. """ - def __init__(self, strict_mode: bool = False): + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): """Create a rejection sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. """ super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -312,7 +318,8 @@ def _create_output( # proposal methods that require KV cache. We can fix it by "prefilling" # the bonus token in the proposer. The following issue tracks the fix. # https://github.com/vllm-project/vllm/issues/4212 - output_with_bonus_tokens[:, -1] = -1 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 # Fill the recovered token ids. output.mul_(~after_false_mask).add_( diff --git a/vllm/sequence.py b/vllm/sequence.py index 42b508b517200..3cebb85b49d27 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -612,6 +612,12 @@ def __init__( self._token_chunk_size = token_chunk_size self.do_sample = do_sample + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + self.num_speculative_tokens = None + if self._token_chunk_size is None: if is_prompt: self._token_chunk_size = list(seq_data.values())[0].get_len() diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 84ec974806c7e..a4e759095b294 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch @@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): def create_worker( cls, scorer_worker: WorkerBase, - draft_worker_kwargs, + draft_worker_kwargs: Dict[str, Any], + disable_by_batch_size: Optional[int], ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -62,7 +63,9 @@ def create_worker( ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + disable_bonus_tokens = True if ngram_prompt_lookup_max > 0: + disable_bonus_tokens = False proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) @@ -75,9 +78,9 @@ def create_worker( return SpecDecodeWorker( proposer_worker, scorer_worker, - # TODO(cade) disable strict mode for speedup. - rejection_sampler=RejectionSampler(strict_mode=True), - ) + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, )) def __init__( self, @@ -85,6 +88,7 @@ def __init__( scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, ): """ Create a SpecDecodeWorker. @@ -97,11 +101,14 @@ def __init__( Worker. rejection_sampler: A Torch module used to perform modified rejection sampling for speculative decoding. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker + self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler self._metrics = AsyncMetricsCollector( @@ -199,27 +206,41 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + disable_all = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + if disable_all: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # This happens for prefill, or when the spec decode is disabled + # for this batch. if execute_model_req.num_lookahead_slots == 0 or len( execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req) + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all) return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to the - proposer and scorer model so that the KV cache is consistent between the - two. + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to + the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. """ - #logger.info("run proposer worker no spec") - - self.proposer_worker.execute_model(execute_model_req) + if not skip_proposer: + self.proposer_worker.execute_model(execute_model_req) - #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -244,22 +265,18 @@ def _run_speculative_decoding_step( sequence. """ - #logger.info("get spec proposals") # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) - #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) - #logger.info("verify proposals") accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) - #logger.info("create output list") return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index eb622a0e2e7f4..ee9462b68dae8 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -56,7 +56,7 @@ def get_proposals( proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices, - ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) if nonzero_proposal_len_seqs: # Speculate tokens using the draft worker for the speculative @@ -97,17 +97,27 @@ def get_proposals( return proposals - def _split_by_max_model_len( + def _split_by_proposal_len( self, seq_group_metadata_list: List[SequenceGroupMetadata], proposal_len: int, ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length.""" + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ proposal_lens: List[int] = [] nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] nonzero_proposal_len_indices: List[int] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has been disabled + # (e.g. due to high traffic). + if seq_group_metadata.num_speculative_tokens == 0: + proposal_lens.append(0) + continue + seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_len = seq_data.get_len() @@ -115,13 +125,14 @@ def _split_by_max_model_len( # are supported. # If max_proposal_len is defined, then we shall no exccess this # quota for nonzero_proposal + new_k = 0 if (self.max_proposal_len is None or seq_len + proposal_len < self.max_proposal_len): - proposal_lens.append(proposal_len) + new_k = proposal_len nonzero_proposal_len_seqs.append(seq_group_metadata) nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k return ( proposal_lens, From 8b9241be3a0020724e145bf600d9710b3d59b167 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Wed, 8 May 2024 16:24:46 -0700 Subject: [PATCH 004/124] [Speculative decoding] [Bugfix] Fix overallocation in ngram + spec logprobs (#4672) --- vllm/spec_decode/ngram_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index fed8be42054a5..f18f9387f5b23 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -138,7 +138,7 @@ def sampler_output( SamplerOutput( outputs=None, sampled_token_probs=token_probs[i], - logprobs=token_logprobs, + logprobs=token_logprobs[i], sampled_token_ids=token_ids[i], )) return outputs, False From e288df0632d5bdde76c20bed8310b46d35b8e5ac Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Wed, 8 May 2024 20:14:31 -0400 Subject: [PATCH 005/124] [Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 48 ++++++++++++++------ 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index fd0837f0cb39c..9c6bff000e916 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -115,7 +115,8 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } -// Constructs destination register by taking bytes from 2 sources (based on mask) +// Constructs destination register by taking bytes from 2 sources (based on +// mask) template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; @@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -1275,13 +1276,22 @@ typedef struct { thread_config_t tb_cfg; } exec_config_t; -thread_config_t thread_configs[] = { +thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default (max cache usage) - {64, 128, 128}, // Reduce N, reduce warps - {128, 64, 128}, // Reduce N more, but increase K + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, }; @@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { - for (auto th_config : thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } @@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } CALL_IF(4, 32, 2, 256) CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) CALL_IF(4, 8, 4, 128) CALL_IF(4, 4, 8, 128) CALL_IF(8, 32, 2, 256) CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) CALL_IF(8, 8, 4, 128) CALL_IF(8, 4, 8, 128) else { From 16bc0a098f6d34050637c3336183fb6966300dd5 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf Date: Thu, 9 May 2024 08:02:31 +0300 Subject: [PATCH 006/124] [Frontend] add tok/s speed metric to llm class when using tqdm (#4400) Co-authored-by: Michael Goin --- vllm/entrypoints/llm.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3ed660e183360..71620139fba39 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -238,17 +238,25 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm(total=num_requests, - desc="Processed prompts", - dynamic_ncols=True) + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=f"Generation Speed: {0:.2f} toks/s", + ) # Run the engine. outputs: List[RequestOutput] = [] + total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() for output in step_outputs: if output.finished: outputs.append(output) if use_tqdm: + total_toks += (sum( + len(stp.token_ids) for stp in output.outputs)) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() @@ -256,4 +264,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs \ No newline at end of file + return outputs From f12b20deccbc6c8bb5cdeac053d75178341c66c1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 9 May 2024 13:48:33 +0800 Subject: [PATCH 007/124] [Frontend] Move async logic outside of constructor (#4674) --- tests/async_engine/test_chat_template.py | 30 ++++---- tests/entrypoints/openai/test_serving_chat.py | 8 ++- vllm/engine/arg_utils.py | 2 +- vllm/entrypoints/openai/api_server.py | 23 +++++- vllm/entrypoints/openai/serving_chat.py | 72 +++++++++---------- vllm/entrypoints/openai/serving_completion.py | 7 +- vllm/entrypoints/openai/serving_engine.py | 56 +++++---------- 7 files changed, 96 insertions(+), 102 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 64bcba67c3437..55b730812ea94 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -60,13 +60,12 @@ class MockServingChat: tokenizer: MockTokenizer -@pytest.mark.asyncio -async def test_load_chat_template(): +def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template( - mock_serving_chat, chat_template=chatml_jinja_path) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -77,8 +76,7 @@ async def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -@pytest.mark.asyncio -async def test_no_load_chat_template_filelike(): +def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() @@ -86,35 +84,33 @@ async def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) -@pytest.mark.asyncio -async def test_no_load_chat_template_literallike(): +def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template assert template_content == template -@pytest.mark.asyncio @pytest.mark.parametrize( "model,template,add_generation_prompt,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) -async def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): +def test_get_gen_prompt(model, template, add_generation_prompt, + expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 269b0823fec05..13e2e372cef33 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -20,11 +20,15 @@ class MockModelConfig: class MockEngine: async def get_model_config(self): - return MockModelConfig + return MockModelConfig() async def _async_serving_chat_init(): - serving_completion = OpenAIServingChat(MockEngine(), + engine = MockEngine() + model_config = await engine.get_model_config() + + serving_completion = OpenAIServingChat(engine, + model_config, served_model_names=[MODEL_NAME], response_role="assistant", chat_template=CHAT_TEMPLATE) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c99b1806c7d1d..5c2acbef13129 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -516,7 +516,7 @@ def add_cli_args( return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 44a946f2e32d4..362f28d05c3bb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,7 +4,7 @@ import re from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Set +from typing import Optional, Set import fastapi import uvicorn @@ -164,15 +164,32 @@ async def authentication(request: Request, call_next): served_model_names = args.served_model_name else: served_model_names = [args.model] + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - openai_serving_chat = OpenAIServingChat(engine, served_model_names, + + event_loop: Optional[asyncio.AbstractEventLoop] + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + openai_serving_chat = OpenAIServingChat(engine, model_config, + served_model_names, args.response_role, args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules) app.root_path = args.root_path uvicorn.run(app, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c8f4a6b315db0..1b469fc59b076 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,4 +1,3 @@ -import asyncio import codecs import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, @@ -8,6 +7,7 @@ from openai.types.chat import (ChatCompletionContentPartParam, ChatCompletionRole) +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, + model_config: ModelConfig, served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules, - await_post_init=self._load_chat_template( - chat_template=chat_template)) + lora_modules=lora_modules) self.response_role = response_role + self._load_chat_template(chat_template) + + def _load_chat_template(self, chat_template: Optional[str]): + tokenizer = self.tokenizer + + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) + else: + logger.warning( + "No chat template provided. Chat API will not work.") def _parse_chat_message_content( self, @@ -357,36 +387,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response - - async def _load_chat_template(self, chat_template: Optional[str]): - while self.tokenizer is None: - # Give the parent class time to load the tokenizer - await asyncio.sleep(0.1) - tokenizer = self.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e - - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning( - "No chat template provided. Chat API will not work.") + return response \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6a7f29c4c96f2..158d8ed7fbbf5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -4,6 +4,7 @@ from fastapi import Request +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, @@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: class OpenAIServingCompletion(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]] = None): + lora_modules: Optional[List[LoRAModulePath]]): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 21baea2e5e7f6..f10718c5f3d80 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,12 @@ -import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import Field -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import Annotated +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, ErrorResponse, @@ -29,13 +28,24 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]], - await_post_init: Optional[Awaitable[Any]] = None): + lora_modules: Optional[List[LoRAModulePath]]): + super().__init__() + self.engine = engine + self.max_model_len = model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + truncation_side="left") + self.served_model_names = served_model_names + if lora_modules is None: self.lora_requests = [] else: @@ -47,38 +57,6 @@ def __init__(self, ) for i, lora in enumerate(lora_modules, start=1) ] - self.max_model_len = 0 - # Lazy initialized - self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - event_loop.create_task(self._post_init(await_post_init)) - else: - # When using single vLLM without engine_use_ray - asyncio.run(self._post_init(await_post_init)) - - async def _post_init(self, await_post_init): - engine_model_config = await self.engine.get_model_config() - self.max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. - self.tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - tokenizer_revision=engine_model_config.tokenizer_revision, - trust_remote_code=engine_model_config.trust_remote_code, - truncation_side="left") - - if await_post_init is not None: - await await_post_init - async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ From 190bc838e17196733526896bf2861f8d05bd3f43 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 00:17:17 -0700 Subject: [PATCH 008/124] [Misc] Remove unnecessary ModelRunner imports (#4703) --- tests/samplers/test_sampler.py | 81 ++++++++++------------------------ tests/test_logits_processor.py | 23 +++------- 2 files changed, 31 insertions(+), 73 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e4fea165a4d46..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -53,7 +46,6 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, device: str, ): @@ -75,7 +67,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, - device) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -448,13 +428,13 @@ def run_test_case(*, ("Invalid test case, expected_penalization does not match computed" "batch size") - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else None, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -480,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sample_probs = None @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 179e8d25a341b..4ee980505a3ab 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -9,7 +9,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -30,21 +30,15 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, logits_processor, model_runner + return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) @@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. @@ -87,8 +80,8 @@ def pick_ith(token_ids, logits): seq_group_metadata_list, seq_lens, query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) + device=device, + pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, @@ -99,5 +92,3 @@ def pick_ith(token_ids, logits): fake_logits *= logits_processor.scale assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], 1e-4) - - del model_runner From 0ee535b2945d042cbb1fc6e63fd3fddd94d491f2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 09:04:59 -0700 Subject: [PATCH 009/124] [Misc] Set block size at initialization & Fix test_model_runner (#4705) --- tests/worker/test_model_runner.py | 90 +++++++++++-------------------- vllm/worker/cpu_model_runner.py | 21 ++++---- vllm/worker/cpu_worker.py | 1 + vllm/worker/model_runner.py | 54 ++++++++----------- vllm/worker/worker.py | 2 +- 5 files changed, 64 insertions(+), 104 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e7975d0ef48b9..3e3d2e3f5c53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,27 +1,38 @@ import pytest import torch -from vllm.config import ModelConfig, SchedulerConfig from vllm.distributed.parallel_state import init_distributed_environment +from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) seq_lens = [] seq_group_metadata_list = [] @@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_decode_cuda_graph(batch_size): - model_config = ModelConfig( + model_runner = _create_model_runner( "facebook/opt-125m", - "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_lens = [] seq_group_metadata_list = [] @@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) @@ -260,29 +248,15 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_hybrid_batches(batch_size, enforce_eager, distributed_init): - - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=enforce_eager, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=True, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=True) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None, - is_driver_worker=True) - model_runner.set_block_size(16) # Add prefill requests. seq_lens = [] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..6c8b1685dadcf 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -4,8 +4,9 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +27,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], @@ -39,27 +41,22 @@ def __init__( self.scheduler_config = scheduler_config # Currently, CPU worker doesn't support chunked prefill. assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend(self.model_config.dtype) # Lazy initialization. self.model: nn.Module # Set after init_Model - self.block_size: int # Set after initial profiling. def load_model(self) -> None: self.model = get_model( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1ef500ac07b8..5e4ae564cb57e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,6 +151,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 46c6730645c1b..b5e582116297c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,8 +9,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -106,6 +107,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", @@ -115,48 +117,40 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() - # Set after load_model. - self.lora_manager: LRUCacheWorkerLoRAManager = None - + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) - - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) - - # Lazy initialization - self.model: torch.nn.Module # Set after load_model - self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables: torch.Tensor # Set after initial profiling. + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + self.attn_backend = get_attn_backend(self.model_config.dtype) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -211,13 +205,6 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size - - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size @@ -835,6 +822,7 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: + assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for idx in range(self.lora_config.max_loras): lora_id = idx + 1 diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 313bcf25d8870..43f6b2b443b70 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -75,6 +75,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, @@ -184,7 +185,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From ff5abcd7463211251bcf19916ae3b45d762f48d4 Mon Sep 17 00:00:00 2001 From: kliuae <17350011+kliuae@users.noreply.github.com> Date: Fri, 10 May 2024 00:19:50 +0800 Subject: [PATCH 010/124] [ROCm] Add support for Punica kernels on AMD GPUs (#3140) Co-authored-by: miloice --- CMakeLists.txt | 16 +- Dockerfile.rocm | 3 + csrc/cuda_compat.h | 6 + csrc/punica/bgmv/bgmv_impl.cuh | 154 +++++++++++++++++++ csrc/punica/bgmv/vec_dtypes.cuh | 5 +- csrc/punica/{punica_ops.cc => punica_ops.cu} | 17 +- csrc/punica/punica_ops.h | 11 ++ csrc/punica/punica_pybind.cpp | 13 ++ csrc/punica/type_convert.h | 82 ++++++++++ setup.py | 6 +- 10 files changed, 287 insertions(+), 26 deletions(-) rename csrc/punica/{punica_ops.cc => punica_ops.cu} (98%) create mode 100644 csrc/punica/punica_ops.h create mode 100644 csrc/punica/punica_pybind.cpp create mode 100644 csrc/punica/type_convert.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f817f3382c5e1..47629f036fb09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -219,7 +219,8 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cc") + "csrc/punica/punica_ops.cu" + "csrc/punica/punica_pybind.cpp") # # Copy GPU compilation flags+update for punica @@ -243,6 +244,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +elseif(${VLLM_GPU_LANG} STREQUAL "HIP") + set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) + message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -277,11 +281,6 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) -endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and @@ -292,3 +291,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") add_dependencies(default _punica_C) endif() endif() + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) +endif() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index d04bb9915e2ab..eefad79e79d83 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -94,6 +94,9 @@ COPY . . RUN python3 -m pip install --upgrade pip numba +# make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 + RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c711d8d1b24b9..1ebb2e74a82fc 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -28,6 +28,12 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index dad8805c750cb..8a3b8403b4a6f 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,24 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +#ifndef USE_ROCM + // nthrs = (32, 4) template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + size_t j = blockIdx.x; + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + __shared__ float y_warpwise[ty]; + + float y = 0; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + x_vec.load(X + (batch_idx * feat_in) + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + + __syncthreads(); + + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + y += sum; + } + } + + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = y; + } + __syncthreads(); + + float y_write = 0.f; +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y_write += y_warpwise[i]; + } + + // write Y; + if (threadIdx.x == 0 && threadIdx.y == 0) { + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); + } +} + +#endif + // nthrs = (2, 16, 4) template @@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } @@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { +#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } +#else + constexpr size_t rocm_warp_size = warpSize; + +#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ + feat_in % (rocm_warp_size * vec_size_) == 0 + +#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ + if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ + constexpr size_t vec_size_shrink = vec_size_; \ + constexpr int tx = tx_; \ + constexpr int ty = ty_; \ + dim3 nblks(feat_out, batch_size); \ + dim3 nthrs(tx, ty); \ + bgmv_shrink_kernel \ + <<>>(Y, X, W, indicies, y_offset, \ + full_y_size, num_layers, layer_idx, \ + scale); \ + } + + static_assert(CHECK_INPUT_TILEABLE_BY(32) || + CHECK_INPUT_TILEABLE_BY(16) || + CHECK_INPUT_TILEABLE_BY( 8) || + CHECK_INPUT_TILEABLE_BY( 4) || + CHECK_INPUT_TILEABLE_BY( 2) || + CHECK_INPUT_TILEABLE_BY( 1)); + + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) + +#undef CHECK_INPUT_TILEABLE_BY +#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM +#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#include -#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +8,9 @@ #include +#include "../type_convert.h" +#include "../../cuda_compat.h" + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 98% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 8797fde85744a..61de3b37937cc 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,11 @@ -#include -#include #include #include #include +#include "type_convert.h" +#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" -namespace { //====== utils ====== @@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..9490ad59cdd5f --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/setup.py b/setup.py index d9ba96b82329a..0dc8818b44a9e 100644 --- a/setup.py +++ b/setup.py @@ -385,12 +385,12 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) + if _install_punica(): + ext_modules.append(CMakeExtension(name="vllm._punica_C")) + package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } From a3c124570a66f746ba09faabe2e14851386b395a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 10 May 2024 00:53:14 +0800 Subject: [PATCH 011/124] [Bugfix] Fix CLI arguments in OpenAI server docs (#4709) --- docs/source/serving/openai_compatible_server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index c157d8ba998da..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -108,5 +108,5 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args :func: make_arg_parser -:prog: vllm-openai-server +:prog: -m vllm.entrypoints.openai.api_server ``` \ No newline at end of file From cea64430f615ff90c67ff8375ec86562913c5500 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 9 May 2024 11:10:13 -0600 Subject: [PATCH 012/124] [Bugfix] Update grafana.json (#4711) --- examples/production_monitoring/grafana.json | 432 +++++++++++--------- 1 file changed, 239 insertions(+), 193 deletions(-) diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 5e9bd5bd03869..273f7f5ac42cf 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -1,4 +1,41 @@ { + "__inputs": [ + { + "name": "DS_PROMETHEUS", + "label": "prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__elements": {}, + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "10.4.2" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], "annotations": { "list": [ { @@ -25,14 +62,14 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 29, + "id": null, "links": [], "liveNow": false, "panels": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "End to end request latency measured in seconds.", "fieldConfig": { @@ -41,6 +78,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -54,6 +92,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -111,7 +150,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -127,7 +166,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -144,7 +183,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -161,7 +200,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -178,7 +217,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:e2e_request_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:e2e_request_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -195,7 +234,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of tokens processed per second", "fieldConfig": { @@ -204,6 +243,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -217,6 +257,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -273,7 +314,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -289,7 +330,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -310,7 +351,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Inter token latency in seconds.", "fieldConfig": { @@ -319,6 +360,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -332,6 +374,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -389,7 +432,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -405,7 +448,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -422,7 +465,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -439,7 +482,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -456,7 +499,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -473,7 +516,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of requests in RUNNING, WAITING, and SWAPPED state", "fieldConfig": { @@ -482,6 +525,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -495,6 +539,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -552,7 +597,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -568,7 +613,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -585,7 +630,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -606,7 +651,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "P50, P90, P95, and P99 TTFT latency in seconds.", "fieldConfig": { @@ -615,6 +660,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -628,6 +674,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -685,7 +732,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -702,7 +749,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -718,7 +765,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -735,7 +782,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -752,7 +799,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_to_first_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_to_first_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -769,7 +816,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Percentage of used cache blocks by vLLM.", "fieldConfig": { @@ -778,6 +825,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -791,6 +839,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -848,7 +897,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:gpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -860,7 +909,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:cpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -875,229 +924,232 @@ "type": "timeseries" }, { - "type": "heatmap", - "title": "Request Prompt Length", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, "description": "Heatmap of request prompt length", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, "gridPos": { - "x": 0, - "y": 24, + "h": 8, "w": 12, - "h": 8 - }, - "datasource": { - "uid": "prometheus", - "type": "prometheus" + "x": 0, + "y": 24 }, "id": 12, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Prompt Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Prompt Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Prompt Length", + "type": "heatmap" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Heatmap of request generation length", "fieldConfig": { "defaults": { "custom": { - "scaleDistribution": { - "type": "linear" - }, "hideFrom": { + "legend": false, "tooltip": false, - "viz": false, - "legend": false + "viz": false + }, + "scaleDistribution": { + "type": "linear" } } }, "overrides": [] }, - "pluginVersion": "10.2.0" - }, - { - "datasource": { - "uid": "prometheus", - "type": "prometheus" - }, - "type": "heatmap", - "title": "Request Generation Length", - "description": "Heatmap of request generation length", "gridPos": { - "x": 12, - "y": 24, + "h": 8, "w": 12, - "h": 8 + "x": 12, + "y": 24 }, "id": 13, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Generation Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Generation Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, - "fieldConfig": { - "defaults": { - "custom": { - "scaleDistribution": { - "type": "linear" - }, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false - } - } - }, - "overrides": [] - }, - "pluginVersion": "10.2.0" + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Generation Length", + "type": "heatmap" }, { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "fieldConfig": { "defaults": { + "color": { + "mode": "palette-classic" + }, "custom": { - "drawStyle": "line", - "lineInterpolation": "linear", + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", "barAlignment": 0, - "lineWidth": 1, + "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", - "spanNulls": false, + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, "insertNulls": false, - "showPoints": "auto", + "lineInterpolation": "linear", + "lineWidth": 1, "pointSize": 5, - "stacking": { - "mode": "none", - "group": "A" - }, - "axisPlacement": "auto", - "axisLabel": "", - "axisColorMode": "text", - "axisBorderShow": false, "scaleDistribution": { "type": "linear" }, - "axisCenteredZero": false, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, - "color": { - "mode": "palette-classic" - }, "mappings": [], "thresholds": { "mode": "absolute", @@ -1123,22 +1175,22 @@ }, "id": 11, "options": { - "tooltip": { - "mode": "single", - "sort": "none" - }, "legend": { - "showLegend": true, + "calcs": [], "displayMode": "list", "placement": "bottom", - "calcs": [] + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -1154,25 +1206,19 @@ } ], "title": "Finish Reason", - "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "type": "timeseries" } ], "refresh": "", - "schemaVersion": 37, - "style": "dark", + "schemaVersion": 39, "tags": [], "templating": { "list": [ { - "current": { - "selected": false, - "text": "vllm", - "value": "vllm" - }, + "current": {}, "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(model_name)", "hide": 0, @@ -1201,6 +1247,6 @@ "timezone": "", "title": "vLLM", "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b", - "version": 2, + "version": 1, "weekStart": "" } From be0c5180ac0832a0b285d0845d458798bb3f0f4f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 9 May 2024 14:36:25 -0400 Subject: [PATCH 013/124] [Bugfix] Add logs for all model dtype casting (#4717) --- vllm/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index a2cb9b32c65fc..275814d72e6c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1063,6 +1063,7 @@ def _get_and_verify_dtype( if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 # models. + logger.info("Casting torch.float32 to torch.float16.") torch_dtype = torch.float16 else: torch_dtype = config_dtype @@ -1087,9 +1088,11 @@ def _get_and_verify_dtype( if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) pass else: # Casting between float16 and bfloat16 is allowed with a warning. From ebce310b7433e050086f52ca48571807df467f50 Mon Sep 17 00:00:00 2001 From: Hao Zhang <152229491+sfc-gh-hazhang@users.noreply.github.com> Date: Thu, 9 May 2024 15:37:14 -0700 Subject: [PATCH 014/124] [Model] Snowflake arctic model implementation (#4652) Co-authored-by: Dash Desai <1723932+iamontheinet@users.noreply.github.com> Co-authored-by: Aurick Qiao Co-authored-by: Aurick Qiao Co-authored-by: Aurick Qiao Co-authored-by: Cody Yu --- examples/offline_inference_arctic.py | 26 + .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/fused_moe.py | 137 +++-- .../layers/quantization/__init__.py | 3 + .../layers/quantization/deepspeedfp.py | 194 +++++++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/arctic.py | 521 ++++++++++++++++++ vllm/transformers_utils/configs/arctic.py | 204 +++++++ 8 files changed, 1042 insertions(+), 48 deletions(-) create mode 100644 examples/offline_inference_arctic.py create mode 100644 vllm/model_executor/layers/quantization/deepspeedfp.py create mode 100644 vllm/model_executor/models/arctic.py create mode 100644 vllm/transformers_utils/configs/arctic.py diff --git a/examples/offline_inference_arctic.py b/examples/offline_inference_arctic.py new file mode 100644 index 0000000000000..1fec3c99eb47c --- /dev/null +++ b/examples/offline_inference_arctic.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="snowflake/snowflake-arctic-instruct", + quantization="deepspeedfp", + tensor_parallel_size=8, + trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. + +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62b..2926c7d1c8a76 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_experts, fused_moe, fused_topk, get_config_file_name) __all__ = [ "fused_moe", + "fused_topk", + "fused_experts", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3cb0419404625..bb7938b3715be 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int, return None -def fused_moe( +def fused_topk( hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. +): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + M, _ = hidden_states.shape - E, N, _ = w1.shape if is_hip(): # The MoE kernels are not yet supported on ROCm. @@ -393,6 +349,33 @@ def fused_moe( del token_expert_indicies # Not used. Will be used in the future. if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + M, _ = hidden_states.shape + E, N, _ = w1.shape if override_config: config = override_config @@ -477,3 +460,63 @@ def fused_moe( out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1c652e347d4ad..5798bc359dcf2 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -19,6 +21,7 @@ "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, + "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000000000..31cdffbcf0ab9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c5cdc059473b3..d5263b500fe0f 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -54,6 +54,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000000000..796cef7c4a735 --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,521 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.arctic import ArcticConfig + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.layers = nn.ModuleList([ + ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + attn_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module): + + def __init__(self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + ) + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping = [] + expert_params_mapping = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 0000000000000..7780bf5e78d6d --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,204 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret From 379da6dcb5f5d062d0452b2fc23291e5113dcf04 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 9 May 2024 16:38:07 -0700 Subject: [PATCH 015/124] [Kernel] [FP8] Improve FP8 linear layer performance (#4691) This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)). We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance. Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4), all FP8 measurements are for dynamic quantization: qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16) qps = 2: 26 ms (FP8, this PR), 34ms (FP8, previous main), 28 ms (FP16) qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16) qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16) qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16) --- vllm/_custom_ops.py | 28 ++++++++++++++++++- .../model_executor/layers/quantization/fp8.py | 13 ++++++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5b56437487477..829c47003ad0e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b57e1dde81a5f..ff996741c1d00 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -231,9 +231,14 @@ def apply(self, # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.act_scale is None and x_scale computed from x. # If static, layer.act_scale is scalar and x_scale set to act_scale. - qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) - - # Fused GEMM_DQ + qinput, x_scale = ops.scaled_fp8_quant(x, + layer.act_scale, + batch_dim_padding=17) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. output, _ = torch._scaled_mm( qinput, layer.weight, @@ -243,7 +248,7 @@ def apply(self, bias=bias, ) - return output + return torch.narrow(output, 0, 0, x.shape[0]) def all_close_1d(x: torch.Tensor) -> bool: From c83310174055bb124ea2197885b652efd59b7a0f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 9 May 2024 17:04:17 -0700 Subject: [PATCH 016/124] [Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535) --- .buildkite/check-wheel-size.py | 2 +- CMakeLists.txt | 2 +- cmake/utils.cmake | 4 +- csrc/attention/attention_kernels.cu | 286 ++++----- csrc/attention/dtype_fp8.cuh | 16 +- csrc/cache.h | 4 +- csrc/cache_kernels.cu | 143 +++-- .../fp8/{amd_detail => amd}/hip_float8.h | 0 .../fp8/{amd_detail => amd}/hip_float8_impl.h | 0 .../fp8/{amd_detail => amd}/quant_utils.cuh | 59 +- .../fp8/{fp8_cuda_kernels.cu => common.cu} | 0 csrc/quantization/fp8/nvidia/quant_utils.cuh | 568 ++++++++++++++++++ .../fp8_e5m2_kvcache/quant_utils.cuh | 277 --------- tests/kernels/test_attention.py | 4 +- tests/kernels/test_cache.py | 27 +- vllm/_custom_ops.py | 7 +- vllm/utils.py | 2 +- 17 files changed, 843 insertions(+), 558 deletions(-) rename csrc/quantization/fp8/{amd_detail => amd}/hip_float8.h (100%) rename csrc/quantization/fp8/{amd_detail => amd}/hip_float8_impl.h (100%) rename csrc/quantization/fp8/{amd_detail => amd}/quant_utils.cuh (81%) rename csrc/quantization/fp8/{fp8_cuda_kernels.cu => common.cu} (100%) create mode 100644 csrc/quantization/fp8/nvidia/quant_utils.cuh delete mode 100644 csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 90a5e54736cf3..41d9e682572a6 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 100 +MAX_SIZE_MB = 150 def print_top_10_largest_files(zip_file): diff --git a/CMakeLists.txt b/CMakeLists.txt index 47629f036fb09..1c7dfe0c048b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,7 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 7c71673e36f29..00c81e4d00ad8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8b1b5e098015f..41b337dd91d36 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -19,21 +19,17 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include + #include "../quantization/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM @@ -92,7 +88,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -157,9 +153,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -223,21 +217,14 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. - k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); } } @@ -312,9 +299,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -348,21 +333,13 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) - v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { - v_vec = *reinterpret_cast(v_ptr + offset); + v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, @@ -448,7 +425,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE> + vllm::Fp8KVCacheDataType KV_DTYPE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -477,7 +454,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + KV_DTYPE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + KV_DTYPE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -629,7 +606,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -706,36 +683,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -752,65 +729,44 @@ void paged_attention_v1( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -897,39 +853,39 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( @@ -949,29 +905,7 @@ void paged_attention_v2( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..2b32ce372a64f 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,14 +3,21 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 +#ifdef ENABLE_FP8 +#ifndef USE_ROCM #include -#endif +#endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache template<> struct Vec { using Type = uint8_t; @@ -30,6 +37,5 @@ template<> struct Vec { using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 } // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 212a3bf3ddc1c..8c176c452425e 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -34,5 +34,7 @@ void reshape_and_cache_flash( // Just for unittest void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache); + const float scale, + const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 76db96f099c69..e5b74da6ad068 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,10 +4,11 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM +#include "quantization/fp8/amd/quant_utils.cuh" +#else +#include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -149,7 +150,7 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); } } } @@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel( } } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x, \ kv_scale); void reshape_and_cache( @@ -285,25 +282,8 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( @@ -353,35 +333,34 @@ void reshape_and_cache_flash( namespace vllm { -template +template __global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, + const float kv_scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif + dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); } } } // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + kv_scale, \ block_stride); +// Only for testing. void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache) + const float kv_scale, + const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); @@ -399,17 +378,35 @@ void convert_fp8( dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8.h rename to csrc/quantization/fp8/amd/hip_float8.h diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8_impl.h rename to csrc/quantization/fp8/amd/hip_float8_impl.h diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh similarity index 81% rename from csrc/quantization/fp8/amd_detail/quant_utils.cuh rename to csrc/quantization/fp8/amd/quant_utils.cuh index 894160972d9f4..df0329f79d361 100644 --- a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,12 +5,17 @@ #include #include +#include "../../../attention/dtype_fp8.cuh" #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" namespace vllm { -namespace fp8_e4m3 { +#ifdef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + template __inline__ __device__ Tout vec_conversion(const Tin& x) { @@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion(const uint3 float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } +#endif // ENABLE_FP8 +template +__inline__ __device__ Tout convert(const Tin &x) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } +#endif + assert(false); } + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // fp8 +#endif // USE_ROCM } // namespace vllm diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/common.cu similarity index 100% rename from csrc/quantization/fp8/fp8_cuda_kernels.cu rename to csrc/quantization/fp8/common.cu diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000..4eeacf7a6f9d9 --- /dev/null +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,568 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + +#if 0 // Disable the following code to reduce the binary size. +template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + res.y = vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} +#endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const float &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} +#endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin &x) { +#if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } +#endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh deleted file mode 100644 index 9bcab25db03cf..0000000000000 --- a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh +++ /dev/null @@ -1,277 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "../../attention/attention_dtypes.h" -#include "../../attention/dtype_float32.cuh" -#include "../../attention/dtype_float16.cuh" -#include "../../attention/dtype_bfloat16.cuh" - - -namespace vllm { -#ifdef ENABLE_FP8_E5M2 -namespace fp8_e5m2_unscaled { - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} - -// fp8 -> half -template<> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - return res.x; -} - -// fp8x2 -> half2 -template<> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); - tmp.u16[0] = res.x; - tmp.u16[1] = res.y; - return tmp.u32; -} - -// fp8x4 -> half2x2 -template<> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template<> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; -} - -// fp8 -> __nv_bfloat16 -template<> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - // Note there is no direct convert function from fp8 to bf16. - // fp8 -> half - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - // half -> float -> bf16 - float tmp = half_to_float(res.x); - return __float2bfloat16(tmp); -} - -// fp8x2 -> __nv_bfloat162 -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; -} - -// fp8x4 -> bf16_4_t -template<> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> bf16_8_t -template<> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - -// fp8 -> float -template<> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - // fp8 -> half - uint16_t tmp = vec_conversion(a); - // half -> float - return half_to_float(tmp); -} - -// fp8x2 -> float2 -template<> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ - // fp8x2 -> half2 - uint32_t tmp = vec_conversion(a); - // half2 -> float2 - return half2_to_float2(tmp); -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> float8 -template<> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - - -// half -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; - __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// bf16 -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - assert(false); -#else - __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -#endif -} - -// float -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; -} - - -template<> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; - - float16 = __float22half2_rn(a); - return uint32; -} - -template<> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); - - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - - return b; -} - -template<> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; -} - -template<> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; -} - -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { - __nv_bfloat162 b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { - bf16_4_t b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { - bf16_8_t b; - from_float(b, a); - return b; -} - -} // namespace fp8_e5m2_unscaled -#endif // ENABLE_FP8_E5M2 -} // namespace vllm diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 84539205e0ae3..28496f187d466 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -236,14 +236,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4cae15c79c489..9f0cb60dc16e2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,8 +5,6 @@ import torch from vllm import _custom_ops as ops -from vllm._C import cache_ops -from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +23,8 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + +# We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -124,8 +124,6 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -149,9 +147,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(cloned_key_cache, key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(cloned_value_cache, value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -165,9 +163,9 @@ def test_reshape_and_cache( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(result_key_cache, key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -255,8 +253,8 @@ def test_reshape_and_cache_flash( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') @@ -299,8 +297,6 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -348,7 +344,6 @@ def test_swap_blocks( dist_value_caches[0][dst].cpu()) -@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -357,7 +352,7 @@ def test_swap_blocks( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_fp8_conversion( +def test_fp8_e4m3_conversion( num_heads: int, head_size: int, block_size: int, @@ -377,9 +372,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache_fp8, cache) converted_cache = torch.empty_like(cache) - ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(converted_cache, cache_fp8) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 829c47003ad0e..35a9f6329fc42 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: - vllm_cache_ops.convert_fp8(output, input) +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) #TODO: cuda_utils, custom_ar diff --git a/vllm/utils.py b/vllm/utils.py index 6479a8dab320a..f0e71f5e99b64 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -329,7 +329,7 @@ def _generate_random_fp8( from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp From 208b71bcc1b94df1fdd2fc10da3e04c706340188 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 19:48:43 -0700 Subject: [PATCH 017/124] [Core][Distributed] refactor pynccl (#4591) [Core][Distributed] refactor pynccl to hold multiple communicators (#4591) --- tests/distributed/test_pynccl.py | 78 ++--- vllm/distributed/communication_op.py | 28 +- .../device_communicators/pynccl.py | 294 +++++------------- .../device_communicators/pynccl_utils.py | 66 ---- .../device_communicators/pynccl_wrapper.py | 258 +++++++++++++++ vllm/distributed/parallel_state.py | 131 ++++---- vllm/worker/model_runner.py | 25 +- vllm/worker/worker.py | 21 -- 8 files changed, 467 insertions(+), 434 deletions(-) delete mode 100644 vllm/distributed/device_communicators/pynccl_utils.py create mode 100644 vllm/distributed/device_communicators/pynccl_wrapper.py diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b6f461b76ed03..b3e30a0434423 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,15 +1,15 @@ import multiprocessing +import os import pytest import torch -import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetUniqueId) -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, - init_distributed_environment, with_pynccl_for_all_reduce) +from vllm.distributed.communication_op import ( # noqa + graph_capture_mode, tensor_model_parallel_all_reduce) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import update_environment_variables @@ -41,6 +41,9 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) init_distributed_environment() fn() @@ -49,11 +52,13 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) + pynccl_comm = PyNcclCommunicator() + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() - assert result == comm.world_size + assert result == pynccl_comm.world_size @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -70,37 +75,35 @@ def multiple_tp_worker_fn(): torch.distributed.new_group(ranks=[2, 3], backend="gloo") ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] - comm = NCCLCommunicator(group=group, device=device) + pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - # two groups can communicate independently - if torch.distributed.get_rank() in [0, 1]: - comm.all_reduce(tensor) - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 - else: - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + with pynccl_comm.change_state(enable=True): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") def test_pynccl_multiple_tp(): # this tests pynccl for multiple tp groups, in a standalone way - # i.e. call `comm.all_reduce` directly + # i.e. call `pynccl_comm.all_reduce` directly distributed_run(multiple_tp_worker_fn, 4) @worker_fn_wrapper def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - torch.cuda.set_device(torch.distributed.get_rank()) ensure_model_parallel_initialized(2, 2) - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with with_pynccl_for_all_reduce(): + with graph_capture_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) @@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + pynccl_comm = PyNcclCommunicator() # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + pynccl_comm.all_reduce(a) + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**1 + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph(): def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() # `list(unique_id.internal)` is something like this: # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 80d03129bdb9b..32ab5694e5390 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,4 +1,5 @@ from collections import namedtuple +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -8,7 +9,26 @@ get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - is_pynccl_enabled_for_all_reduce) + get_tp_pynccl_communicator) + + +@contextmanager +def graph_capture_mode(): + # In graph capture, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor size + # is too large, it will fallback to the next available option. + pynccl_comm = get_tp_pynccl_communicator() + assert pynccl_comm is not None + with pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()): + yield def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ - from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) @@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_pynccl_enabled_for_all_reduce(): - pynccl_utils.all_reduce(input_) + pynccl_comm = get_tp_pynccl_communicator() + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 758994352e3de..168d4cc2df8a6 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,26 +1,4 @@ -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. - -import ctypes -import platform +from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -28,217 +6,70 @@ import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger -from vllm.utils import find_nccl_library, nccl_integrity_check logger = init_logger(__name__) -so_file = find_nccl_library() - -try: - # load the library in another process. - # if it core dumps, it will not crash the current process - nccl_integrity_check(so_file) - nccl = ctypes.CDLL(so_file) -except Exception as e: - logger.error( - "Failed to load NCCL library from %s ." - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s." - "One solution is to download libnccl2 version 2.18 from " - "https://developer.download.nvidia.com/compute/cuda/repos/ " - "and extract the libnccl.so.2 file. If you already have the " - "library, please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) - raise e - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int - -_c_ncclGetErrorString = nccl.ncclGetErrorString -_c_ncclGetErrorString.restype = ctypes.c_char_p -_c_ncclGetErrorString.argtypes = [ncclResult_t] - - -def NCCL_CHECK(result: ncclResult_t) -> None: - if result != 0: - error_str = _c_ncclGetErrorString(result) - error_str = error_str.decode("utf-8") - raise RuntimeError(f"NCCL error: {error_str}") - - -# equivalent to c declaration: -# ncclResult_t ncclGetVersion(int *version); -_c_ncclGetVersion = nccl.ncclGetVersion -_c_ncclGetVersion.restype = ctypes.c_int -_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - - -def ncclGetVersion() -> str: - version = ctypes.c_int() - NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) - # something like 21903 --> "2.19.3" - version_str = str(version.value) - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - -class NcclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -# equivalent to c declaration: -# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -_c_ncclGetUniqueId = nccl.ncclGetUniqueId -_c_ncclGetUniqueId.restype = ctypes.c_int -_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] - - -def ncclGetUniqueId() -> NcclUniqueId: - unique_id = NcclUniqueId() - NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) - return unique_id - - -# equivalent to c declaration: -# ncclResult_t ncclCommInitRank( -# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -# note that ncclComm_t is a pointer type, so the first argument -# is a pointer to a pointer -_c_ncclCommInitRank = nccl.ncclCommInitRank -_c_ncclCommInitRank.restype = ctypes.c_int -_c_ncclCommInitRank.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int -] - -ncclDataType_t = ctypes.c_int - - -class ncclDataTypeEnum: - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -ncclRedOp_t = ctypes.c_int - - -class ncclRedOpTypeEnum: - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise ValueError(f"Unsupported op: {op}") - - -# equivalent to c declaration: -# ncclResult_t ncclAllReduce( -# const void* sendbuff, void* recvbuff, size_t count, -# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, -# udaStream_t stream); -# note that cudaStream_t is a pointer type, so the last argument is a pointer -_c_ncclAllReduce = nccl.ncclAllReduce -_c_ncclAllReduce.restype = ctypes.c_int -_c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t, - ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p -] - -# be cautious! this is a collective call, it will block until all -# processes in the communicator have called this function. -# because Python object destruction can happen in random order, -# it is better not to call it at all. -# equivalent to c declaration: -# ncclResult_t ncclCommDestroy(ncclComm_t comm); -_c_ncclCommDestroy = nccl.ncclCommDestroy -_c_ncclCommDestroy.restype = ctypes.c_int -_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] - - -class NCCLCommunicator: +class PyNcclCommunicator: def __init__( self, group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, + library_path: Optional[str] = None, ): """ Args: group: the process group to work on. If None, it will use the default process group. - device: the device to bind the NCCLCommunicator to. If None, + device: the device to bind the PyNcclCommunicator to. If None, it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "NCCLCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group.") self.group = group # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + if self.rank == 0: - self.unique_id = ncclGetUniqueId() + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() else: - self.unique_id = NcclUniqueId() + # construct an empty unique id + self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) # arg `src` in `broadcast` is the global rank @@ -246,7 +77,6 @@ def __init__( byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte - self.comm = ctypes.c_void_p() if device is None: local_rank = get_local_rank() device = torch.device(f"cuda:{local_rank}") @@ -261,15 +91,25 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - NCCL_CHECK( - _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank)) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() + # A small all_reduce for warmup. + self.all_reduce(torch.zeros(1, device=device)) + self.stream.synchronize() + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + if self.disabled: + return # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -278,10 +118,32 @@ def all_reduce(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - NCCL_CHECK( - _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream))) + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py deleted file mode 100644 index 44e4f39217a41..0000000000000 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -import contextlib -from typing import Optional - -import torch -from torch.distributed import ProcessGroup, ReduceOp - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetVersion) -except Exception as e: - # in non-NVIDIA environments, we can't import the nccl module - # e.g. when running on machines with AMD GPUs - logger.info("Failed to import NCCL library: %s", e) - logger.info("It is expected if you are not running on NVIDIA GPUs.") - pass - -comm: Optional["NCCLCommunicator"] = None - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return comm is not None - - -@contextlib.contextmanager -def set_pynccl_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - try: - assert comm is not None - comm.stream = stream - yield - finally: - pass - - -def init_process_group(group: Optional[ProcessGroup] = None) -> None: - assert not is_initialized() - global comm - logger.info("vLLM is using nccl==%s", ncclGetVersion()) - comm = NCCLCommunicator(group=group) - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - assert comm is not None - comm.all_reduce(input_, op) - - -def destroy_process_group() -> None: - global comm - comm = None - - -def get_world_size() -> int: - """Returns the world size.""" - assert comm is not None - return comm.world_size - - -def get_nccl_backend() -> Optional["NCCLCommunicator"]: - return comm diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000000..43d85674b23d0 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,258 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library, nccl_integrity_check + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + # load the library in another process. + # if it core dumps, it will not crash the current process + nccl_integrity_check(so_file) + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "One solution is to download libnccl2 version 2.18 from " + "https://developer.download.nvidia.com/compute/cuda/repos/ " + "and extract the libnccl.so.2 file. If you already have the " + "library, please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be5bb4e857caf..5075da11bb1b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,10 +3,10 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" -import contextlib -from typing import Optional +from typing import List, Optional import torch +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger @@ -14,10 +14,11 @@ logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TP_DEVICE_GROUP = None -_TP_CPU_GROUP = None +_TP_DEVICE_GROUP: Optional[ProcessGroup] = None +_TP_CPU_GROUP: Optional[ProcessGroup] = None +_TP_PYNCCL_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None +_PP_DEVICE_GROUP: Optional[ProcessGroup] = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -41,11 +42,16 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None +_PP_GLOBAL_RANKS: Optional[List[int]] = None _LOCAL_RANK = -1 +def get_tp_pynccl_communicator(): + global _TP_PYNCCL_COMMUNICATOR + return _TP_PYNCCL_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK @@ -80,10 +86,20 @@ def init_distributed_environment( # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1 and distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank global _LOCAL_RANK _LOCAL_RANK = local_rank + # A small all_reduce for warmup. + data = torch.zeros(1) + if torch.cuda.is_available(): + data = data.to(device=f"cuda:{local_rank}") + torch.distributed.all_reduce(data) def initialize_model_parallel( @@ -133,29 +149,36 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) group = torch.distributed.new_group(ranks, backend=backend) cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + global _PP_DEVICE_GROUP + global _PP_GLOBAL_RANKS + assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks + _PP_DEVICE_GROUP = group + _PP_GLOBAL_RANKS = ranks def ensure_model_parallel_initialized( @@ -188,8 +211,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP_DEVICE_GROUP is not None - and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): @@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group(): def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + assert _PP_DEVICE_GROUP is not None, ( "pipeline model parallel group is not initialized") - return _PIPELINE_MODEL_PARALLEL_GROUP + return _PP_DEVICE_GROUP def get_tensor_model_parallel_world_size(): @@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank(): def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") - return _PIPELINE_GLOBAL_RANKS[0] + return _PP_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + return _PP_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that precedes the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def destroy_model_parallel(): @@ -295,45 +317,12 @@ def destroy_model_parallel(): if _TP_CPU_GROUP: torch.distributed.destroy_process_group(_TP_CPU_GROUP) _TP_CPU_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - if _PIPELINE_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _PIPELINE_GLOBAL_RANKS - _PIPELINE_GLOBAL_RANKS = None - from vllm.distributed.device_communicators import pynccl_utils - - # Destroy the pynccl states if any. - pynccl_utils.destroy_process_group() - - -# Whether to use pynccl for nccl all reduce. -# We use pynccl for all reduce when using CUDA graph, because torch.distributed -# is not well supported by CUDA graph. -_ENABLE_PYNCCL_FOR_ALL_REDUCE = False - - -@contextlib.contextmanager -def with_pynccl_for_all_reduce(): - from vllm.distributed.device_communicators import pynccl_utils - """use pynccl instead of torch.distributed for all reduce""" - tp_size = get_tensor_model_parallel_world_size() - if tp_size == 1: - # No-op. - # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. - yield - else: - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - old = _ENABLE_PYNCCL_FOR_ALL_REDUCE - _ENABLE_PYNCCL_FOR_ALL_REDUCE = True - - stream = torch.cuda.current_stream() - with pynccl_utils.set_pynccl_stream(stream): - yield - _ENABLE_PYNCCL_FOR_ALL_REDUCE = old - - -def is_pynccl_enabled_for_all_reduce(): - """check if pynccl is enabled for all reduce""" - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - return _ENABLE_PYNCCL_FOR_ALL_REDUCE + global _TP_PYNCCL_COMMUNICATOR + _TP_PYNCCL_COMMUNICATOR = None + + global _PP_DEVICE_GROUP + if _PP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) + _PP_DEVICE_GROUP = None + global _PP_GLOBAL_RANKS + _PP_GLOBAL_RANKS = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e582116297c..3fc76c6142165 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import contextlib import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -12,9 +11,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce -from vllm.distributed.device_communicators import (custom_all_reduce, - pynccl_utils) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.communication_op import graph_capture_mode +from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -917,10 +916,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: Since it is used for decoding-only, it assumes there's only 1 token per sequence in the batch. """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - self.pynccl_backend = pynccl_utils.get_nccl_backend() - assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " @@ -1046,7 +1041,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): + with graph_capture_mode(): self.model( input_ids, positions, @@ -1061,7 +1056,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): + with graph_capture_mode(): hidden_states = self.model( input_ids, positions, @@ -1113,16 +1108,6 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -@contextlib.contextmanager -def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): - with with_pynccl_for_all_reduce(): - yield - else: - yield - - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 43f6b2b443b70..0ca9c2b64cf30 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,9 +11,7 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - get_tensor_model_parallel_cpu_group, init_distributed_environment) -from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -306,29 +304,10 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - if pynccl_utils.is_initialized(): - pynccl_world_size = pynccl_utils.get_world_size() - if pynccl_world_size != parallel_config.world_size: - raise RuntimeError( - "pynccl is already initialized but the pynccl world " - "size does not match parallel_config.world_size " - f"({pynccl_world_size} vs. {parallel_config.world_size}).") - elif parallel_config.world_size > 1: - # NOTE(woosuk): We don't initialize pynccl process group when world size - # is 1. - # NOTE(kaichao): By default, pynccl is initialized for tp group. - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) - # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: init_custom_ar() - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From e965d4618430db94fed73d4305c44c9147e47150 Mon Sep 17 00:00:00 2001 From: "Allen.Dou" Date: Fri, 10 May 2024 12:42:38 +0800 Subject: [PATCH 018/124] [Misc] Keep only one implementation of the create_dummy_prompt function. (#4716) --- tests/test_sequence.py | 36 ++++-------------------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..53061278d5be4 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,36 +1,8 @@ -import time -from typing import Optional - import pytest -from vllm import SamplingParams -from vllm.lora.request import LoRARequest -from vllm.sequence import (SamplerOutput, Sequence, SequenceData, - SequenceGroup, SequenceGroupOutput, SequenceOutput) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, - best_of: int = 1, -) -> SequenceGroup: - if not block_size: - block_size = prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) - - return seq_group +from tests.core.utils import create_dummy_prompt +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) @pytest.fixture @@ -102,7 +74,7 @@ def test_sequence_data_prefill(): def test_sequence_group_stage(): - seq_group = create_dummy_prompt("1", 12) + _, seq_group = create_dummy_prompt("1", 12) assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(6) assert seq_group.is_prefill() is True From 51d4094fda63b1d738f55ae9dd75d354b9c1143c Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 9 May 2024 22:13:23 -0700 Subject: [PATCH 019/124] chunked-prefill-doc-syntax (#4603) Fix the docs: https://docs.vllm.ai/en/latest/models/performance.html Co-authored-by: sang --- docs/source/models/performance.rst | 36 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst index 067757699f32a..589fce21056c2 100644 --- a/docs/source/models/performance.rst +++ b/docs/source/models/performance.rst @@ -7,7 +7,7 @@ Chunked Prefill --------------- vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. -You can enable the feature by specifying +You can enable the feature by specifying ``--enable-chunked-prefill`` in the command line or setting ``enable_chunked_prefill=True`` in the LLM constructor. .. code-block:: python @@ -16,23 +16,29 @@ You can enable the feature by specifying # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) -By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. +This policy optimizes the TTFT (time to the first token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. -Once chunked prefill is enabled, the policy is changed to +Once chunked prefill is enabled, the policy is changed to prioritize decode requests. +It batches all pending decode requests to the batch before scheduling any prefill. +When there are available token_budget (``max_num_batched_tokens``), it schedules pending prefills. +If a last pending prefill request cannot fit into ``max_num_batched_tokens``, it chunks it. -- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. -- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. +This policy has two benefits: -This policy has two benefits. - -- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It improves ITL and generation decode because decode requests are prioritized. - It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. -You can tune the performance by changing `max_num_batched_tokens`. -By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. -Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. -Higher batch size achieves better TTFT as you can put more prefill to the batch. -If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). -Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. +You can tune the performance by changing ``max_num_batched_tokens``. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B). +Smaller ``max_num_batched_tokens`` achieves better ITL because there are fewer prefills interrupting decodes. +Higher ``max_num_batched_tokens`` achieves better TTFT as you can put more prefill to the batch. + +- If ``max_num_batched_tokens`` is the same as ``max_model_len``, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +- Note that the default value (512) of ``max_num_batched_tokens`` is optimized for ITL, and it may have lower throughput than the default scheduler. + +We recommend you set ``max_num_batched_tokens > 2048`` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). -See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). +Please try out this feature and let us know your feedback via GitHub issues! \ No newline at end of file From 64b77dfd7e1378853ec7b189f3d7d0e51ce18855 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 10 May 2024 20:52:48 +0800 Subject: [PATCH 020/124] [Core]fix type annotation for `swap_blocks` (#4726) --- vllm/_custom_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 35a9f6329fc42..42dedfdf76c4f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch @@ -266,7 +266,7 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]) -> None: + block_mapping: torch.Tensor) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) From dac6a3f6ed14ea4061b672f9290bfdf8bcdd996d Mon Sep 17 00:00:00 2001 From: Steve Grubb Date: Fri, 10 May 2024 09:37:05 -0400 Subject: [PATCH 021/124] [Misc] Apply a couple g++ cleanups (#4719) --- csrc/cpu/cache.cpp | 2 +- csrc/cpu/pos_encoding.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 620d11ef1ed6c..26e81685d623e 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -84,7 +84,7 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, const torch::Tensor& block_mapping) { - int num_layers = key_caches.size(); + unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e9b3992204bb2..5dc1bde45ac5f 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -19,7 +19,6 @@ void rotary_embedding_impl( const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); - constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); From 6a0f617210dfba76f3db4db1155d1f1489609133 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 10 May 2024 23:54:32 +0900 Subject: [PATCH 022/124] [Core] Fix circular reference which leaked llm instance in local dev env (#4737) Storing exception frame is extremely prone to circular refernece because it contains the reference to objects. When tensorizer is not installed, it leaks llm instance because error frame has references to various modules which cause circular reference problem. I also found spec decoding has a circular reference issue, and I solved it using weakref.proxy. --- tests/basic_correctness/test_basic_correctness.py | 13 +++++++++++++ vllm/model_executor/model_loader/tensorizer.py | 10 +++++----- vllm/spec_decode/multi_step_worker.py | 3 ++- vllm/spec_decode/ngram_worker.py | 3 ++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d75279dd9cfa9..7d8117447ca0a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -3,9 +3,12 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import os +import weakref import pytest +from vllm import LLM + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -13,6 +16,16 @@ VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" +def test_vllm_gc_ed(): + """Verify vllm instance is GC'ed when it is deleted""" + llm = LLM("facebook/opt-125m") + weak_llm = weakref.ref(llm) + del llm + # If there's any circular reference to vllm, this fails + # because llm instance is not GC'ed. + assert weak_llm() is None + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index af433b86e604d..219a2a392e129 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = None +tensorizer_error_msg = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -28,7 +28,7 @@ from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) except ImportError as e: - tensorizer_load_fail = e + tensorizer_error_msg = str(e) __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', @@ -254,11 +254,11 @@ class TensorizerAgent: def __init__(self, tensorizer_config: TensorizerConfig, quant_config: QuantizationConfig, **extra_kwargs): - if tensorizer_load_fail is not None: + if tensorizer_error_msg is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`." - ) from tensorizer_load_fail + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) self.tensorizer_config = tensorizer_config self.tensorizer_args = ( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5044cc1ef85fd..20098ebaeea32 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,4 +1,5 @@ import copy +import weakref from typing import List, Tuple import torch @@ -32,7 +33,7 @@ def init_device(self): super().init_device() self._proposer = Top1Proposer( - self, + weakref.proxy(self), self.device, self.vocab_size, max_proposal_len=self.max_model_len, diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index f18f9387f5b23..6cd50fcc1a041 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,3 +1,4 @@ +import weakref from typing import List, Optional, Tuple import torch @@ -37,7 +38,7 @@ def init_device(self): # Current only support Top1Proposer self._proposer = Top1Proposer( - self, + weakref.proxy(self), device=self.device, vocab_size=self.vocab_size, ) From 706588a77d2099b118f53a53ef2dd7f8c2de9ffc Mon Sep 17 00:00:00 2001 From: "Allen.Dou" Date: Fri, 10 May 2024 23:00:56 +0800 Subject: [PATCH 023/124] [Bugfix] Fix CLI arguments in OpenAI server docs (#4729) --- docs/requirements-docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 0e76763a87b7c..ed569816200ee 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -10,3 +10,4 @@ pydantic torch py-cpuinfo transformers +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args From 2e7796f2cf4537dd3b08d5de3aa8349f5db1a168 Mon Sep 17 00:00:00 2001 From: heeju-kim2 <157340754+heeju-kim2@users.noreply.github.com> Date: Sat, 11 May 2024 02:36:25 +0900 Subject: [PATCH 024/124] [Speculative decoding] CUDA graph support (#4295) Co-authored-by: Cade Daniel --- .../e2e/test_multistep_correctness.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94d71fb012727..d2da039e84c07 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -611,3 +611,40 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) From fcc2994be657a897dd0732928754749048520b28 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 10 May 2024 16:01:01 -0600 Subject: [PATCH 025/124] [CI] Nits for bad initialization of SeqGroup in testing (#4748) --- tests/core/test_block_manager.py | 13 +++++++++---- tests/core/utils.py | 11 +++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9db58e075196d..22a9f0cf47d32 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -142,8 +142,10 @@ def test_append_slot_cow(): child = prompt.fork(new_seq_id=2) # Allocate space for the sequence group. - seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), - time.time(), time.perf_counter) + seq_group = SequenceGroup(request_id="1", + seqs=[prompt, child], + arrival_time=time.time(), + sampling_params=SamplingParams()) block_manager.allocate(seq_group) # Fork and append a new token id. We expect a COW to be scheduled. @@ -303,8 +305,11 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks parent = Sequence(1, "one two three", [0, 1, 2], block_size) - seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), - None) + seq_group = SequenceGroup(request_id="1", + seqs=[parent], + arrival_time=time.time(), + sampling_params=SamplingParams(), + lora_request=None) block_manager.allocate(seq_group) # assert the number of blocks allocated is correct diff --git a/tests/core/utils.py b/tests/core/utils.py index 22c1d3826dff4..8fb13177a2d6c 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -22,10 +22,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) + seq_group = SequenceGroup(request_id=request_id, + seqs=[prompt], + arrival_time=time.time(), + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + lora_request=lora_request) return prompt, seq_group From 4e12131089f192334f6e09c8fe5cd85af1e25327 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 15:14:40 -0700 Subject: [PATCH 026/124] [Core][Test] fix function name typo in custom allreduce (#4750) --- tests/distributed/test_custom_all_reduce.py | 4 ++-- vllm/distributed/device_communicators/custom_all_reduce.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3b1cd1773af19..308b874280f55 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: with custom_all_reduce.capture(): @@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index cc5f8166877ce..5d26254fb832a 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -52,6 +52,10 @@ def init_custom_ar() -> None: "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") return + + # we only use a subset of GPUs here + # so we only need to check the nvlink connectivity of these GPUs + num_dev = world_size # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES From e254497b66dcd87038969b0ad34d34425edfc5fe Mon Sep 17 00:00:00 2001 From: Chang Su Date: Sat, 11 May 2024 11:30:37 -0700 Subject: [PATCH 027/124] [Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734) --- examples/offline_inference_embedding.py | 17 ++ examples/openai_embedding_client.py | 23 ++ requirements-dev.txt | 9 +- tests/conftest.py | 38 ++- .../output_processor/test_multi_step.py | 12 +- tests/entrypoints/openai/test_serving_chat.py | 1 + tests/entrypoints/test_openai_server.py | 96 ++++++- tests/models/test_embedding.py | 44 +++ tests/samplers/test_logits_processor.py | 6 +- tests/samplers/test_seeded_generate.py | 2 +- tests/spec_decode/utils.py | 6 +- tests/test_sequence.py | 12 +- vllm/__init__.py | 7 +- vllm/config.py | 15 + vllm/core/embedding_model_block_manager.py | 84 ++++++ vllm/core/interfaces.py | 5 + vllm/core/scheduler.py | 10 +- vllm/engine/arg_utils.py | 1 + vllm/engine/async_llm_engine.py | 158 +++++++++-- vllm/engine/llm_engine.py | 143 ++++++++-- vllm/entrypoints/llm.py | 150 ++++++++-- vllm/entrypoints/openai/api_server.py | 20 +- vllm/entrypoints/openai/protocol.py | 36 ++- vllm/entrypoints/openai/serving_embedding.py | 134 +++++++++ vllm/entrypoints/openai/serving_engine.py | 16 +- vllm/executor/gpu_executor.py | 10 +- vllm/model_executor/layers/pooler.py | 56 ++++ vllm/model_executor/layers/sampler.py | 7 +- vllm/model_executor/models/__init__.py | 12 +- vllm/model_executor/models/llama_embedding.py | 87 ++++++ vllm/model_executor/pooling_metadata.py | 69 +++++ vllm/outputs.py | 82 +++++- vllm/pooling_params.py | 20 ++ vllm/sequence.py | 89 +++++- vllm/spec_decode/util.py | 5 +- vllm/worker/embedding_model_runner.py | 266 ++++++++++++++++++ vllm/worker/model_runner.py | 25 +- vllm/worker/worker.py | 14 +- 38 files changed, 1627 insertions(+), 160 deletions(-) create mode 100644 examples/offline_inference_embedding.py create mode 100644 examples/openai_embedding_client.py create mode 100644 tests/models/test_embedding.py create mode 100644 vllm/core/embedding_model_block_manager.py create mode 100644 vllm/entrypoints/openai/serving_embedding.py create mode 100644 vllm/model_executor/layers/pooler.py create mode 100644 vllm/model_executor/models/llama_embedding.py create mode 100644 vllm/model_executor/pooling_metadata.py create mode 100644 vllm/pooling_params.py create mode 100644 vllm/worker/embedding_model_runner.py diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py new file mode 100644 index 0000000000000..7d5ef128bc8e0 --- /dev/null +++ b/examples/offline_inference_embedding.py @@ -0,0 +1,17 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.encode(prompts) +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 4096 floats diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py new file mode 100644 index 0000000000000..b73360fe15a24 --- /dev/null +++ b/examples/openai_embedding_client.py @@ -0,0 +1,23 @@ +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +responses = client.embeddings.create(input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" +], + model=model) + +for data in responses.data: + print(data.embedding) # list of float of len 4096 diff --git a/requirements-dev.txt b/requirements-dev.txt index e6d375cbafa39..796c9e37d0230 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,12 +19,15 @@ pytest-forked pytest-asyncio pytest-rerunfailures pytest-shard -httpx + +# testing utils +awscli einops # required for MPT +httpx +peft requests ray -peft -awscli +sentence-transformers # required for embedding # Benchmarking aiohttp diff --git a/tests/conftest.py b/tests/conftest.py index 1f2ad1cbd7298..b8117a19c75d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -133,6 +133,10 @@ def example_long_prompts() -> List[str]: "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, } +_EMBEDDING_MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + class HfRunner: @@ -145,14 +149,7 @@ def __init__( assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] self.model_name = model_name - if model_name not in _VISION_LANGUAGE_MODELS: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = None - else: + if model_name in _VISION_LANGUAGE_MODELS: self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( model_name, torch_dtype=torch_dtype, @@ -162,6 +159,20 @@ def __init__( model_name, torch_dtype=torch_dtype, ) + elif model_name in _EMBEDDING_MODELS: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype).cuda() + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() + self.processor = None if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -334,6 +345,9 @@ def generate_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + def __del__(self): del self.model cleanup() @@ -459,6 +473,14 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + def __del__(self): del self.model cleanup() diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 6da3da091db78..2bf4bf69da203 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, - SequenceStatus) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 13e2e372cef33..74b49726734b5 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -14,6 +14,7 @@ class MockModelConfig: tokenizer_mode = "auto" max_model_len = 100 tokenizer_revision = None + embedding_mode = False @dataclass diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index e53e64a0c1ff8..c22ac4507658b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -23,6 +23,7 @@ MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" @@ -121,7 +122,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -150,6 +151,25 @@ def server(zephyr_lora_files): ray.shutdown() +@pytest.fixture(scope="module") +def embedding_server(zephyr_lora_files): + ray.shutdown() + ray.init() + server_runner = ServerRunner.remote([ + "--model", + EMBEDDING_MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + @pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( @@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + input = [ + "The chef prepared a delicious meal.", + ] + + # test single embedding + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 9 + assert embeddings.usage.total_tokens == 9 + + # test using token IDs + input = [1, 1, 1, 1, 1] + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 5 + assert embeddings.usage.total_tokens == 5 + + +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + # test List[str] + inputs = [ + "The cat sat on the mat.", "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky." + ] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) == 4096 + + # test List[List[int]] + inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], + [25, 32, 64, 77]] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 4 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 17 + assert embeddings.usage.total_tokens == 17 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py new file mode 100644 index 0000000000000..59bf054913f7c --- /dev/null +++ b/tests/models/test_embedding.py @@ -0,0 +1,44 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_llama_embedding.py`. +""" +import pytest +import torch +import torch.nn.functional as F + +MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.encode(example_prompts) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.encode(example_prompts) + del vllm_model + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3788e9e9752ff..be4c2ea1b7810 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -36,14 +36,14 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[0], - sampling_params=params_with_logprobs, + params=params_with_logprobs, prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[1], - sampling_params=SamplingParams( + params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), @@ -53,7 +53,7 @@ def pick_vllm(token_ids, logits): # test grouped requests vllm_model.model._add_request( prompt=example_prompts[2], - sampling_params=SamplingParams(max_tokens=max_tokens), + params=SamplingParams(max_tokens=max_tokens), prompt_token_ids=None, ) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58da..ce4501bbf71e5 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -60,7 +60,7 @@ def test_random_sample_with_seed( llm._add_request( prompt=prompt, prompt_token_ids=None, - sampling_params=params, + params=params, ) results = llm._run_engine(use_tqdm=False) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f288652d51556..d52b22c30bd43 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -7,8 +7,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SamplerOutput, SequenceData, - SequenceGroupMetadata, SequenceGroupOutput, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine @@ -170,7 +170,7 @@ def create_sampler_output_list( return [ SamplerOutput(outputs=[ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( output_token=token_id, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 53061278d5be4..b8ea1f6b77200 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,17 +1,17 @@ import pytest from tests.core.utils import create_dummy_prompt -from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, + SequenceData, SequenceOutput) @pytest.fixture def sample_outputs(): return [ - SequenceGroupOutput(samples=[ + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) ], - prompt_logprobs=None) for i in range(5) + prompt_logprobs=None) for i in range(5) ] @@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs): def test_sampler_output_setitem(sampler_output): - new_output = SequenceGroupOutput(samples=[ + new_output = CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) ], - prompt_logprobs=None) + prompt_logprobs=None) sampler_output[2] = new_output assert sampler_output[2] == new_output diff --git a/vllm/__init__.py b/vllm/__init__.py index 59810da3ca411..74674ca0d12af 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -6,7 +6,9 @@ from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster from vllm.model_executor.models import ModelRegistry -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, RequestOutput) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams __version__ = "0.4.2" @@ -17,9 +19,12 @@ "SamplingParams", "RequestOutput", "CompletionOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", "initialize_ray_cluster", + "PoolingParams", ] diff --git a/vllm/config.py b/vllm/config.py index 275814d72e6c3..fab9cfbf41a2d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) +from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron @@ -22,6 +23,7 @@ logger = init_logger(__name__) _GB = 1 << 30 +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 class ModelConfig: @@ -126,6 +128,7 @@ def __init__( served_model_name) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -137,6 +140,11 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_embedding_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.embedding_mode = any( + ModelRegistry.is_embedding_model(arch) for arch in architectures) + def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] @@ -591,6 +599,7 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. + embedding_mode: Whether the running model is for embedding. """ def __init__( @@ -602,6 +611,7 @@ def __init__( num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -610,6 +620,10 @@ def __init__( # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. self.max_num_batched_tokens = 512 + elif embedding_mode: + # For embedding, choose specific value for higher throughput + self.max_num_batched_tokens = max( + max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. @@ -623,6 +637,7 @@ def __init__( self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self.embedding_mode = embedding_mode self._verify_args() diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py new file mode 100644 index 0000000000000..a09d79ec3c420 --- /dev/null +++ b/vllm/core/embedding_model_block_manager.py @@ -0,0 +1,84 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup + + +class EmbeddingModelBlockSpaceManager(BlockSpaceManager): + """An embedding version of BlockSpaceManager for use in environments + with embedding models where block management is not required. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b2a5e41990f39..689cbc2179ee1 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,6 +35,11 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 + if version == "embedding": + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 35e3db18f1c43..fb6e985b2f31c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -270,9 +270,14 @@ def __init__( self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") + version) # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( @@ -968,6 +973,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, + pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5c2acbef13129..163723b4be364 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -574,6 +574,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37a2dc77a3b50..a31f10b7748d3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,7 +14,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -47,15 +48,16 @@ def _raise_exception_on_finish( class AsyncStream: - """A stream of RequestOutputs for a request that can be - iterated over asynchronously.""" + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously.""" def __init__(self, request_id: str) -> None: self.request_id = request_id self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -71,7 +73,7 @@ def finished(self) -> bool: def __aiter__(self): return self - async def __anext__(self) -> RequestOutput: + async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]: result = await self._queue.get() if isinstance(result, Exception): raise result @@ -108,7 +110,8 @@ def propagate_exception(self, self.abort_request(rid) def process_request_output(self, - request_output: RequestOutput, + request_output: Union[RequestOutput, + EmbeddingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -196,7 +199,8 @@ def has_new_requests(self): class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" - async def step_async(self) -> List[RequestOutput]: + async def step_async( + self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -251,7 +255,7 @@ async def add_request_async( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -270,8 +274,8 @@ async def add_request_async( return self.add_request(request_id, prompt=prompt, + params=params, prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, arrival_time=arrival_time, lora_request=lora_request, multi_modal_data=multi_modal_data) @@ -511,7 +515,7 @@ async def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -528,9 +532,9 @@ async def add_request( max_log_len] logger.info( "Received request %s: prompt: %r, " - "sampling_params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, - sampling_params, shortened_token_ids, lora_request) + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) if not self.is_running: if self.start_engine_loop: @@ -562,7 +566,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, prompt=prompt, - sampling_params=sampling_params, + params=params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, @@ -597,8 +601,8 @@ async def generate( multi_modal_data: Multi modal data per request. Yields: - The output `RequestOutput` objects from the LLMEngine for the - request. + The output `RequestOutput` objects from the LLMEngine + for the request. Details: - If the engine is not running, start the background loop, @@ -643,25 +647,123 @@ async def generate( >>> # Process and return the final output >>> ... """ - # Preprocess the request. - arrival_time = time.time() - - try: - stream = await self.add_request( + async for output in self.process_request( request_id, prompt, sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data, - ) + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def encode( + self, + prompt: Optional[str], + pooling_params: PoolingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data per request. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in self.process_request( + request_id, + prompt, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def process_request( + self, + request_id: str, + prompt: Optional[str], + params: Union[SamplingParams, PoolingParams], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + """Common logic to process requests with SamplingParams or + PoolingParams.""" + arrival_time = time.time() + + stream = await self.add_request( + request_id, + prompt, + params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) + try: async for request_output in stream: yield request_output except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the - # request. self._abort(request_id) raise e diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9938b045ba2b..46fa41030b4a1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,9 +20,12 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + MultiModalData, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -169,7 +172,8 @@ def __init__( load_config=load_config, ) - self._initialize_kv_caches() + if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -354,7 +358,7 @@ def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -370,7 +374,8 @@ def add_request( request_id: The unique ID of the request. prompt: The prompt string. Can be None if prompt_token_ids is provided. - sampling_params: The sampling parameters for text generation. + params: Parameters for sampling or pooling. SamplingParams + for text generation. PoolingParams for pooling. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use @@ -404,13 +409,6 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.time() prompt_token_ids = self.encode_request( @@ -432,6 +430,50 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() @@ -443,11 +485,35 @@ def add_request( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + multi_modal_data=multi_modal_data) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + pooling_params=pooling_params) + return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -484,13 +550,25 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() + def _process_sequence_group_outputs( + self, + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + def _process_model_outputs( self, - output: List[SamplerOutput], + output: List[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[RequestOutput]: + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -510,6 +588,9 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + if self.model_config.embedding_mode: + self._process_sequence_group_outputs(seq_group, outputs) + continue self.output_processor.process_prompt_logprob(seq_group, outputs) if seq_group_meta.do_sample: @@ -519,18 +600,19 @@ def _process_model_outputs( self.scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[RequestOutput] = [] + request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) for seq_group in ignored_seq_groups: - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) return request_outputs - def step(self) -> List[RequestOutput]: + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png @@ -570,7 +652,7 @@ def step(self) -> List[RequestOutput]: >>> while True: >>> if example_inputs: >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> engine.add_request(str(req_id),prompt,sampling_params) >>> >>> # continue the request processing >>> request_outputs = engine.step() @@ -637,12 +719,15 @@ def _get_stats( # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 0. + if num_total_gpu is not None: + num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( + ) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. - if num_total_cpu > 0: + if num_total_cpu is not None and num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) @@ -716,8 +801,10 @@ def _get_stats( seq.get_output_len() for seq in seq_group.get_finished_seqs() ]) - best_of_requests.append(seq_group.sampling_params.best_of) - n_requests.append(seq_group.sampling_params.n) + if seq_group.sampling_params is not None: + best_of_requests.append( + seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) finished_reason_requests.extend([ SequenceStatus.get_finished_reason(seq.status) for seq in seq_group.get_finished_seqs() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 71620139fba39..25f4428100b27 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,13 +6,17 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter +logger = init_logger(__name__) + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -164,8 +168,89 @@ def generate( multi_modal_data: Multi modal data. Returns: - A list of `RequestOutput` objects containing the generated - completions in the same order as the input prompts. + A list of `RequestOutput` objects containing the + generated completions in the same order as the input prompts. + """ + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + sampling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def encode( + self, + prompts: Optional[Union[str, List[str]]] = None, + pooling_params: Optional[Union[PoolingParams, + List[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + generated embeddings in the same order as the input prompts. + """ + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def _validate_and_prepare_requests( + self, + prompts: Optional[Union[str, List[str]]], + params: Union[Union[SamplingParams, PoolingParams], + List[Union[SamplingParams, + PoolingParams]]], # Unified parameter + prompt_token_ids: Optional[List[List[int]]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[dict]: + """Validates and prepares request data for adding to the engine. + + Ensures prompts and token IDs are consistent, and returns a list of + dictionaries with request data for further processing. """ if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " @@ -188,40 +273,43 @@ def generate( assert prompt_token_ids is not None num_requests = len(prompt_token_ids) - if sampling_params is None: - # Use default sampling params. - sampling_params = SamplingParams() - - elif isinstance(sampling_params, - list) and len(sampling_params) != num_requests: - raise ValueError("The lengths of prompts and sampling_params " + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " "must be the same.") if multi_modal_data: multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. + requests_data = [] for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request( + + multi_modal_item = MultiModalData( + type=multi_modal_data.type, + data=multi_modal_data.data[i].unsqueeze(0), + ) if multi_modal_data else None + + requests_data.append({ + "prompt": prompt, - sampling_params[i] - if isinstance(sampling_params, list) else sampling_params, + "params": + params[i] if isinstance(params, list) else params, + "prompt_token_ids": token_ids, - lora_request=lora_request, - # Get ith image while maintaining the batch dim. - multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, - ) - return self._run_engine(use_tqdm) + "lora_request": + lora_request, + "multi_modal_data": + multi_modal_item, + }) + + return requests_data def _add_request( self, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, @@ -229,12 +317,14 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, - sampling_params, + params, prompt_token_ids, lora_request=lora_request, multi_modal_data=multi_modal_data) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine( + self, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -245,7 +335,7 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. - outputs: List[RequestOutput] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() @@ -253,10 +343,12 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: if output.finished: outputs.append(output) if use_tqdm: - total_toks += (sum( - len(stp.token_ids) for stp in output.outputs)) - spd = total_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_toks += sum( + len(stp.token_ids) for stp in output.outputs) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 362f28d05c3bb..7cd51b959a0ea 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -22,9 +22,11 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, ErrorResponse) + CompletionRequest, + EmbeddingRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -32,6 +34,8 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding + logger = init_logger(__name__) _running_tasks: Set[asyncio.Task] = set() @@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) +@app.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await openai_serving_embedding.create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + return JSONResponse(content=generator.model_dump()) + + if __name__ == "__main__": args = parse_args() @@ -190,7 +205,8 @@ async def authentication(request: Request, call_next): args.chat_template) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules) - + openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, + served_model_names) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3cd9ddad3b7b7..139c5716c7cea 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,13 +1,14 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import torch from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -363,6 +364,24 @@ def check_guided_decoding_count(cls, data): return data +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: List[float] + + +class EmbeddingResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + class ChatMessage(OpenAIBaseModel): role: str content: str diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000000000..7a57be0c88915 --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,134 @@ +import time +from typing import AsyncIterator, List, Tuple + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_completion import parse_prompt_format +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] + + +def request_output_to_embedding_response( + final_res_batch: List[EmbeddingRequestOutput], + request_id: str, + created_time: int, + model_name: str, +) -> EmbeddingResponse: + data = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + + embedding_data = EmbeddingResponseData( + index=idx, embedding=final_res.outputs.embedding) + data.append(embedding_data) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str]): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + self._check_embedding_mode(model_config.embedding_mode) + + async def create_embedding(self, request: EmbeddingRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # Return error for unsupported features. + if request.encoding_format == "base64": + return self.create_error_response( + "base64 encoding is not currently supported") + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + prompt_is_tokens, prompts = parse_prompt_format(request.input) + pooling_params = request.to_pooling_params() + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + prompt_ids, prompt_text = prompt_formats + + generators.append( + self.engine.generate(prompt_text, + pooling_params, + f"{request_id}-{i}", + prompt_token_ids=prompt_ids)) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + + return response + + def _check_embedding_mode(self, embedding_mode: bool): + if not embedding_mode: + logger.warning( + "embedding_mode is False. Embedding API will not work.") + else: + logger.info("Activating the server engine with embedding enabled.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f10718c5f3d80..58a1c2f7e73fe 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,7 +9,8 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, ErrorResponse, + CompletionRequest, + EmbeddingRequest, ErrorResponse, LogProbs, ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger @@ -165,7 +166,8 @@ def _maybe_get_lora( def _validate_prompt_and_tokenize( self, - request: Union[ChatCompletionRequest, CompletionRequest], + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None @@ -191,6 +193,16 @@ def _validate_prompt_and_tokenize( prompt_ids) token_num = len(input_ids) + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, EmbeddingRequest): + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.", ) + return input_ids, input_text + if request.max_tokens is None: if token_num >= self.max_model_len: raise ValueError( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index fa3480fa64837..2b72b31b5f070 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -123,8 +123,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = self.driver_worker.execute_model(execute_model_req) return output @@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model )(execute_model_req=execute_model_req, ) return output diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000000000..445b30b8c6e9b --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,56 @@ +from enum import IntEnum + +import torch +import torch.nn as nn + +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. + """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools specific information from hidden states based on metadata.""" + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + if self.pooling_type == PoolingType.LAST: + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_flat_indices] + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e52e350d2726f..c8bab46c83eca 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,8 +10,9 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceOutput) # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] @@ -1019,7 +1020,7 @@ def _build_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d5263b500fe0f..6aec104be8da4 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) # Architecture -> (module, class). -_MODELS = { +_GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b @@ -58,6 +58,12 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -114,6 +120,10 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def is_embedding_model(model_arch: str) -> bool: + return model_arch in _EMBEDDING_MODELS + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py new file mode 100644 index 0000000000000..8f1c77da50d96 --- /dev/null +++ b/vllm/model_executor/models/llama_embedding.py @@ -0,0 +1,87 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + + +class LlamaEmbeddingModel(nn.Module): + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py new file mode 100644 index 0000000000000..b86cafce85d12 --- /dev/null +++ b/vllm/model_executor/pooling_metadata.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + + +class PoolingMetadata: + """Metadata for pooling operations in the Pooler layer. + + This class holds the necessary information for pooling operations, + providing context for how to perform pooling and other related operations. + + Attributes: + seq_groups: List of (seq_ids, pooling_params). + seq_data: A mapping of sequence ID to additional sequence data. + prompt_lens: List of the lengths of each prompt. + """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], PoolingParams]], + seq_data: Dict[int, Any], # Specific data related to sequences + prompt_lens: List[int], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + + def __repr__(self) -> str: + return ("PoolingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens})") + + +@dataclass +class PoolingTensors: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + + @classmethod + def from_pooling_metadata( + cls, + pooling_metadata: "PoolingMetadata", + device: torch.device, + ) -> "PoolingTensors": + """ + Create PoolingTensors from PoolingMetadata. + + Args: + pooling_metadata: PoolingMetadata instance to convert. + device: Device to store the tensors. + """ + # Convert prompt lengths to tensor + pin_memory = is_pin_memory_available() + + prompt_lens_t = torch.tensor( + pooling_metadata.prompt_lens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + + return cls(prompt_lens=prompt_lens_t.to(device=device, + non_blocking=True), ) diff --git a/vllm/outputs.py b/vllm/outputs.py index d01be0eb0efd2..f9bce9e683f22 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -57,8 +57,27 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +class EmbeddingOutput: + """The output data of one completion output of a request. + + Args: + embedding: The embedding vector, which is a list of floats. The + length of vector depends on the model as listed in the embedding guide. + """ + + def __init__( + self, + embedding: List[float], + ) -> None: + self.embedding = embedding + + def __repr__(self) -> str: + return (f"EmbeddingOutput(" + f"embedding={len(self.embedding)}") + + class RequestOutput: - """The output data of a request to the LLM. + """The output data of a completion request to the LLM. Args: request_id: The unique ID of the request. @@ -93,6 +112,9 @@ def __init__( @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + if seq_group.sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -148,3 +170,61 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request})") + + +class EmbeddingRequestOutput: + """ + The output data of an embedding request to the LLM. + + Args: + request_id (str): A unique identifier for the embedding request. + outputs (EmbeddingOutput): The embedding results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the embedding is completed. + """ + + def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + if seq_group.embeddings is None: + raise ValueError( + "Embeddings are missing in seq_group for EmbeddingRequest.") + output = EmbeddingOutput(seq_group.embeddings) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an EmbeddingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the EmbeddingRequestOutput instance. + """ + return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py new file mode 100644 index 0000000000000..3b95d73ddc2c5 --- /dev/null +++ b/vllm/pooling_params.py @@ -0,0 +1,20 @@ +from typing import Any, Optional + + +class PoolingParams: + """Pooling parameters for pooling. + + Attributes: + additional_data: Any additional data needed for pooling. + """ + + def __init__(self, additional_data: Optional[Any] = None): + self.additional_data = additional_data + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data, ) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 3cebb85b49d27..46ac33b7ecabd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,11 +1,13 @@ """Sequence and its related classes.""" import copy import enum +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -375,12 +377,12 @@ class SequenceGroupState: class MultiModalData: """Multi modal request. - + Args: type: The data type. data: The actual data. The required shape and semantic meaning of it depends on the vision - language config of the hosted model. + language config of the hosted model. See `VisionLanguageConfig` in `config.py`. """ @@ -402,16 +404,22 @@ class SequenceGroup: arrival_time: The arrival time of the request. lora_request: LoRA request. multi_modal_data: Multi modal data associated with the request. + embeddings: The embeddings vectors of the prompt of the sequence group + for an embedding model. + pooling_params: The pooling parameters used to generate the pooling + for an embedding model. """ def __init__( self, request_id: str, seqs: List[Sequence], - sampling_params: SamplingParams, arrival_time: float, + sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + embeddings: Optional[List[float]] = None, + pooling_params: Optional[PoolingParams] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -425,6 +433,8 @@ def __init__( self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() self.multi_modal_data = multi_modal_data + self.embeddings = embeddings + self.pooling_params = pooling_params @property def prompt(self) -> str: @@ -479,12 +489,13 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params.use_beam_search: + if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. return self.sampling_params.best_of else: - if self.sampling_params.best_of > self.num_seqs(): + if (self.sampling_params + and self.sampling_params.best_of > self.num_seqs()): # At prompt stage, the sequence group is not yet filled up # and only have one sequence running. However, in the # generation stage, we will have `best_of` sequences running. @@ -555,7 +566,7 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def is_prefill(self) -> bool: - # Every sequences should be in the same stage. + # Every sequence should be in the same stage. return self.get_seqs()[0].is_prefill() def __repr__(self) -> str: @@ -594,6 +605,7 @@ def __init__( sampling_params: SamplingParams, block_tables: Dict[int, List[int]], do_sample: bool = True, + pooling_params: Optional[PoolingParams] = None, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, @@ -605,6 +617,7 @@ def __init__( self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.pooling_params = pooling_params self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data @@ -669,8 +682,20 @@ def __eq__(self, other: object) -> bool: return equal and log_probs_equal -class SequenceGroupOutput: - """The model output associated with a sequence group.""" +class SequenceGroupOutput(ABC): + """The base class for model outputs associated with a sequence group.""" + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass + + +class CompletionSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with a completion sequence group.""" def __init__( self, @@ -682,26 +707,45 @@ def __init__( self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " + return (f"CompletionSequenceGroupOutput(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutput): + if not isinstance(other, CompletionSequenceGroupOutput): raise NotImplementedError() return (self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) +class EmbeddingSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with an embedding sequence group.""" + + def __init__( + self, + embeddings: List[float], + ) -> None: + self.embeddings = embeddings + + def __repr__(self) -> str: + return (f"EmbeddingSequenceGroupOutput(" + f"embeddings_shape={len(self.embeddings)})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EmbeddingSequenceGroupOutput): + raise NotImplementedError() + return self.embeddings == other.embeddings + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. - This datastructure implements methods so it can be used like a list, but + This data structure implements methods, so it can be used like a list, but also has optional fields for device tensors. """ - outputs: List[SequenceGroupOutput] + outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None @@ -742,6 +786,27 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") +@dataclass +class PoolerOutput: + """The output from a pooling operation in the embedding model.""" + outputs: List[EmbeddingSequenceGroupOutput] + + spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + @dataclass class ExecuteModelRequest: """The model execution request.""" diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index d6f80c82b80bf..4dc6c49eb58d2 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,7 +4,8 @@ import torch -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) SeqId = int @@ -94,7 +95,7 @@ def create_sequence_group_output( for topk_logprob_index, _ in enumerate(topk_token_ids) }) - return SequenceGroupOutput( + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py new file mode 100644 index 0000000000000..2d3f160c60dc1 --- /dev/null +++ b/vllm/worker/embedding_model_runner.py @@ -0,0 +1,266 @@ +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import BatchType, ModelRunner + +logger = init_logger(__name__) + + +class EmbeddingModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> Optional[PoolerOutput]: + (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) + + return self.model.pooler(hidden_states=hidden_states, + pooling_metadata=pooling_metadata) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, + Set[LoRARequest], LoRAMapping, torch.Tensor]: + if self.is_driver_worker: + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) + + # Prepare PoolingMetadata + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + prompt_lens) + + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, + } + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + assert decode_attn_metadata is not None + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_input = metadata_dict.pop("multi_modal_input") + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + pooling_metadata = PoolingMetadata(seq_groups=None, + seq_data=None, + prompt_lens=None) + + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. + if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + + return (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3fc76c6142165..21d76fd531e49 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,6 +1,6 @@ import time from enum import IntEnum -from typing import Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch @@ -287,18 +287,18 @@ def _prepare_prompt( lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend( - [lora_id] * - (seq_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + lora_prompt_mapping.extend([lora_id] * ( + seq_len - context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( seq_group_metadata.multi_modal_data.data) - if seq_group_metadata.block_tables is None: + if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. + # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue @@ -813,7 +813,6 @@ def profile_run(self) -> None: sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1139,3 +1138,15 @@ def _prepare_fake_inputs( prompt_tokens = [0] * seq_len fake_image_input = None return SequenceData(prompt_tokens), fake_image_input + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0ca9c2b64cf30..e4fbc877b8c9f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed @@ -16,8 +16,9 @@ init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -68,7 +69,9 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self.model_runner = ModelRunner( + ModelRunnerClass = (EmbeddingModelRunner if + self.model_config.embedding_mode else ModelRunner) + self.model_runner = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -83,7 +86,8 @@ def __init__( # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: CacheEngine - self.gpu_cache: List[torch.Tensor] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[torch.tensor]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -209,7 +213,7 @@ def cache_swap( def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: if execute_model_req is None: seq_group_metadata_list = None From 6eaccb7353cfe84d77981da726f6d82a8aefd2be Mon Sep 17 00:00:00 2001 From: Yikang Shen Date: Sun, 12 May 2024 00:27:24 -0400 Subject: [PATCH 028/124] [Model] Add support for IBM Granite Code models (#4636) --- vllm/model_executor/models/llama.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fce..127e4612b2e40 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -58,15 +58,16 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: Optional[QKVParallelLinear] = None, + bias: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, - bias=False, + bias=bias, quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, - bias=False, + bias=bias, quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -209,6 +210,7 @@ def __init__( intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,6 +350,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, From a709e87a4f35c1637d2bbea4bc2c9e5fe7fd70b5 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sun, 12 May 2024 20:46:31 -0400 Subject: [PATCH 029/124] [CI/Build] Tweak Marlin Nondeterminism Issues (#4713) --- tests/models/test_gptq_marlin.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index 4d73843f970c4..b1c2b88bc99af 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -1,13 +1,11 @@ """Compares the outputs of gptq vs gptq_marlin Note: GPTQ and Marlin do not have bitwise correctness. As a result, in this test, we just confirm that the top selected tokens of the -Marlin/GPTQ models are in the top 3 selections of each other. +Marlin/GPTQ models are in the top 5 selections of each other. Note: Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin. As a result, we re-run the test up to 3 times to see if we pass. -Note: This test currently fails running with --forked with the following: - RuntimeError: Cannot re-initialize CUDA in forked subprocess. - To use CUDA with multiprocessing, you must use the 'spawn' start method + Run `pytest tests/models/test_gptq_marlin.py`. """ import os @@ -49,7 +47,7 @@ ] -@pytest.mark.flaky(reruns=2) +@pytest.mark.flaky(reruns=3) @pytest.mark.skipif(gptq_marlin_not_supported, reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) @@ -75,7 +73,7 @@ def test_models( tensor_parallel_size=1) gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs) del gptq_marlin_model # Run gptq. @@ -85,7 +83,7 @@ def test_models( quantization="gptq", max_model_len=MAX_MODEL_LEN, tensor_parallel_size=1) - gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts[:-1], max_tokens, num_logprobs) del gptq_model From a7be4d00725db5ae4f738f70c3a89fd9dedaf7ec Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 12 May 2024 20:47:47 -0400 Subject: [PATCH 030/124] [CORE] Improvement in ranks code (#4718) --- vllm/model_executor/layers/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c8bab46c83eca..a84f562909d50 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices] - return (x > vals[:, None]).long().sum(1).add_(1) + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) def _get_logprobs( From 702bee461f448b0186eb9d673baad29fd923c884 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 17:47:59 -0700 Subject: [PATCH 031/124] [Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754) --- tests/distributed/test_comm_ops.py | 22 +- tests/distributed/test_custom_all_reduce.py | 87 +++-- tests/distributed/test_pynccl.py | 4 +- vllm/distributed/communication_op.py | 45 ++- .../device_communicators/custom_all_reduce.py | 320 ++++++++++-------- .../device_communicators/pynccl.py | 4 +- vllm/distributed/parallel_state.py | 28 +- vllm/test_utils.py | 17 +- vllm/worker/model_runner.py | 15 +- vllm/worker/worker.py | 11 +- 10 files changed, 327 insertions(+), 226 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 9a7a1f07e1b8d..a4423bbfddf46 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -16,7 +16,7 @@ @ray.remote(num_gpus=1, max_calls=1) -def all_reduce_test_worker(tensor_parallel_size: int, rank: int, +def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tensor_parallel_size) + (r + 1) for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank] @@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def all_gather_test_worker(tensor_parallel_size: int, rank: int, +def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) @@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, all_tensors = [ torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(tensor_size) * (r + 1) - for r in range(tensor_parallel_size) + for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) t = all_tensors[rank] @@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, +def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor @@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker ]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) +def test_multi_process_tensor_parallel(tp_size, test_target): + multi_process_tensor_parallel(tp_size, 1, test_target) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 308b874280f55..bdca031e39be1 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,8 +6,10 @@ import torch import torch.distributed as dist -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators import custom_all_reduce +from vllm.distributed.communication_op import ( # noqa + graph_capture, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_ca_communicator) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -18,17 +20,36 @@ @ray.remote(num_gpus=1, max_calls=1) -def graph_allreduce(world_size, rank, distributed_init_port): +def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) - custom_all_reduce.init_custom_ar() + group = get_tensor_model_parallel_group() + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_all_reduce.capture(): + with graph_capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port): torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): - out1 = tensor_model_parallel_all_reduce(inp1) - # the input buffer is immediately modified to test - # synchronization - dist.all_reduce(inp1) - out2 = tensor_model_parallel_all_reduce(inp2) - dist.all_reduce(inp2) + for i in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) graph.replay() assert torch.allclose(out1, inp1) assert torch.allclose(out2, inp2) @ray.remote(num_gpus=1, max_calls=1) -def eager_allreduce(world_size, rank, distributed_init_port): +def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 sz = 1024 - custom_all_reduce.init_custom_ar() - fa = custom_all_reduce.get_handle() + fa = get_tp_ca_communicator() inp = torch.ones(sz, dtype=torch.float32, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) - - -if __name__ == "__main__": - multi_process_tensor_parallel(2, graph_allreduce) +def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b3e30a0434423..a0f7500bf0ee9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.communication_op import ( # noqa - graph_capture_mode, tensor_model_parallel_all_reduce) + graph_mode, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with graph_capture_mode(): + with graph_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 32ab5694e5390..9cc776f8324f2 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,5 @@ from collections import namedtuple -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -9,12 +9,13 @@ get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_ca_communicator, get_tp_pynccl_communicator) @contextmanager -def graph_capture_mode(): - # In graph capture, we have to be very careful about the collective +def graph_mode(): + # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | # -------------------------------------------- @@ -24,10 +25,32 @@ def graph_capture_mode(): # # Note that custom allreduce will have a runtime check, if the tensor size # is too large, it will fallback to the next available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or pynccl if it is disabled or not supported. pynccl_comm = get_tp_pynccl_communicator() - assert pynccl_comm is not None - with pynccl_comm.change_state(enable=True, - stream=torch.cuda.current_stream()): + if pynccl_comm is None: + context = nullcontext() + else: + context = pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()) + with context: + yield + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should include the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. + """ + ca_comm = get_tp_ca_communicator() + context = nullcontext() if ca_comm is None else ca_comm.capture() + with context: yield @@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ - from vllm.distributed.device_communicators.custom_all_reduce import ( - custom_all_reduce) + ca_comm = get_tp_ca_communicator() # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ - out = custom_all_reduce(input_) - if out is not None: - return out + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out pynccl_comm = get_tp_pynccl_communicator() if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 5d26254fb832a..30ee9d1f8a1e9 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,154 +1,42 @@ from contextlib import contextmanager -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.parallel_state import ( + get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger try: import pynvml from vllm._C import custom_ar + + @contextmanager + def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + except ImportError: # For AMD GPUs custom_ar = None pynvml = None -logger = init_logger(__name__) + @contextmanager + def _nvml(): + try: + yield + finally: + pass -_CA_HANDLE: Optional["CustomAllreduce"] = None -_IS_CAPTURING = False -_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - -def init_custom_ar() -> None: - from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - global _CA_HANDLE - if _CA_HANDLE is not None: - return - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - if world_size == 1: - # No need to initialize custom allreduce for single GPU case. - return - - if world_size not in _SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom allreduce is disabled due to an unsupported world size: " - "%d. Supported world sizes: %s. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.", world_size, - str(_SUPPORTED_WORLD_SIZES)) - return - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warning( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return - - # we only use a subset of GPUs here - # so we only need to check the nvlink connectivity of these GPUs - num_dev = world_size - # test nvlink first, this will filter out most of the cases - # where custom allreduce is not supported - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES - if cuda_visible_devices: - device_ids = list(map(int, cuda_visible_devices.split(","))) - else: - device_ids = list(range(num_dev)) - # this checks hardware and driver support for NVLink - full_nvlink = _is_full_nvlink(device_ids) - if world_size > 2 and not full_nvlink: - logger.warning( - "Custom allreduce is disabled because it's not supported on more" - " than two PCIe-only GPUs. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - # test P2P capability, this checks software/cudaruntime support - # this is expensive to compute at the first time - # then we cache the result - if not _can_p2p(rank, world_size): - logger.warning( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) - - -def begin_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = True - - -def end_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = False - - -def is_capturing() -> bool: - return _IS_CAPTURING and _CA_HANDLE is not None - - -def get_handle() -> Optional["CustomAllreduce"]: - return _CA_HANDLE - - -def is_initialized() -> bool: - return _CA_HANDLE is not None - - -@contextmanager -def capture(): - try: - begin_capture() - yield - finally: - end_capture() - handle = get_handle() - if handle is not None: - handle.register_graph_buffers() - - -def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: - ca_handle = get_handle() - # when custom allreduce is disabled, this will be None - if ca_handle is None: - return None - if is_capturing(): - if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_reg(input) - else: - if ca_handle.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) - else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_unreg(input) - - return None - - -@contextmanager -def _nvml(): - try: - pynvml.nvmlInit() - yield - finally: - pynvml.nvmlShutdown() + +logger = init_logger(__name__) @_nvml() @@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool: class CustomAllreduce: + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + # max_size: max supported allreduce size def __init__(self, - rank, - world_size, - full_nvlink, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if custom_ar is None: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + group = group or get_tensor_model_parallel_cpu_group() + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch.cuda.device_count())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + full_nvlink = _is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, - device="cuda") + device=self.device) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -211,8 +189,9 @@ def __init__(self, # needs less than 10000 of registered tuples. self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, - device="cuda") + device=self.device) self.max_size = max_size + self.rank = rank self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink @@ -221,6 +200,21 @@ def __init__(self, self.full_nvlink) self.register_buffer(self.buffer) + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + def _get_ipc_meta(self, inp: torch.Tensor): data = inp.untyped_storage()._share_cuda_() shard_data = ( @@ -230,14 +224,29 @@ def _get_ipc_meta(self, inp: torch.Tensor): return self._gather_ipc_meta(shard_data) def _gather_ipc_meta(self, shard_data): - all_data: List[Optional[Any]] = [None] * self.world_size - dist.all_gather_object(all_data, shard_data) + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. handles = [] offsets = [] for i in range(len(all_data)): - handles.append(all_data[i][0]) # type: ignore - offsets.append(all_data[i][1]) # type: ignore + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore return handles, offsets def register_buffer(self, inp: torch.Tensor): @@ -269,8 +278,31 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled: + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + if self.should_custom_ar(input): + return self.all_reduce_reg(input) + else: + if self.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if self.should_custom_ar(input): + return self.all_reduce_unreg(input) + + return None + def close(self): - if self._ptr: + if not self.disabled and self._ptr: custom_ar.dispose(self._ptr) self._ptr = 0 diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 168d4cc2df8a6..092a0910329ad 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -96,8 +96,10 @@ def __init__( self.stream = torch.cuda.Stream() # A small all_reduce for warmup. - self.all_reduce(torch.zeros(1, device=device)) + data = torch.zeros(1, device=device) + self.all_reduce(data) self.stream.synchronize() + del data # by default it is disabled, e.g. in profiling models and prefill phase. # to use it, use under `with obj.change_state(enable=True)`, usually diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5075da11bb1b8..d24104e3ed276 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -13,10 +13,13 @@ logger = init_logger(__name__) +_ENABLE_CUSTOM_ALL_REDUCE = True + # Tensor model parallel group that the current rank belongs to. _TP_DEVICE_GROUP: Optional[ProcessGroup] = None _TP_CPU_GROUP: Optional[ProcessGroup] = None _TP_PYNCCL_COMMUNICATOR = None +_TP_CA_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. _PP_DEVICE_GROUP: Optional[ProcessGroup] = None @@ -47,11 +50,21 @@ _LOCAL_RANK = -1 +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + def get_tp_pynccl_communicator(): global _TP_PYNCCL_COMMUNICATOR return _TP_PYNCCL_COMMUNICATOR +def get_tp_ca_communicator(): + global _TP_CA_COMMUNICATOR + return _TP_CA_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK @@ -100,6 +113,9 @@ def init_distributed_environment( if torch.cuda.is_available(): data = data.to(device=f"cuda:{local_rank}") torch.distributed.all_reduce(data) + if torch.cuda.is_available(): + torch.cuda.synchronize() + del data def initialize_model_parallel( @@ -149,7 +165,8 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR + global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): @@ -168,6 +185,15 @@ def initialize_model_parallel( device=_LOCAL_RANK, ) + # Initialize a custom fast all-reduce implementation. + if _ENABLE_CUSTOM_ALL_REDUCE: + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + _TP_CA_COMMUNICATOR = CustomAllreduce( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + # Build the pipeline model-parallel groups. global _PP_DEVICE_GROUP global _PP_GLOBAL_RANKS diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 0cf23e4bb7e75..addd8ec1c26c8 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -6,24 +6,24 @@ def init_test_distributed_environment( - pipeline_parallel_size: int, - tensor_parallel_size: int, + tp_size: int, + pp_size: int, rank: int, distributed_init_port: str, local_rank: int = -1, ) -> None: distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - world_size=pipeline_parallel_size * tensor_parallel_size, + world_size=pp_size * tp_size, rank=rank, distributed_init_method=distributed_init_method, local_rank=local_rank) - ensure_model_parallel_initialized(tensor_parallel_size, - pipeline_parallel_size) + ensure_model_parallel_initialized(tp_size, pp_size) def multi_process_tensor_parallel( - tensor_parallel_size: int, + tp_size: int, + pp_size: int, test_target, ) -> None: # Using ray helps debugging the error when it failed @@ -32,10 +32,9 @@ def multi_process_tensor_parallel( distributed_init_port = get_open_port() refs = [] - for rank in range(tensor_parallel_size): + for rank in range(tp_size * pp_size): refs.append( - test_target.remote(tensor_parallel_size, rank, - distributed_init_port)) + test_target.remote(tp_size, pp_size, rank, distributed_init_port)) ray.get(refs) ray.shutdown() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 21d76fd531e49..f46b475bdc2db 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,8 +12,7 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture_mode -from vllm.distributed.device_communicators import custom_all_reduce +from vllm.distributed.communication_op import graph_capture, graph_mode from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -942,13 +941,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - with custom_all_reduce.capture(): + with graph_capture(): # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): @@ -1040,7 +1033,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with graph_capture_mode(): + with graph_mode(): self.model( input_ids, positions, @@ -1055,7 +1048,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with graph_capture_mode(): + with graph_mode(): hidden_states = self.model( input_ids, positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e4fbc877b8c9f..82cf58101a95b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,9 +11,8 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.distributed.device_communicators.custom_all_reduce import ( - init_custom_ar) + init_distributed_environment, + set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput @@ -302,16 +301,14 @@ def init_worker_distributed_environment( local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - # Initialize a custom fast all-reduce implementation. - if not parallel_config.disable_custom_all_reduce: - init_custom_ar() - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From 350f9e107f0c00e59be1b970f96395494ed68b48 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 13 May 2024 22:50:09 +0800 Subject: [PATCH 032/124] [CI/Build] Move `test_utils.py` to `tests/utils.py` (#4425) Since #4335 was merged, I've noticed that the definition of ServerRunner in the tests is the same as in the test for OpenAI API. I have moved the class to the test utilities to avoid code duplication. (Although it only has been repeated twice so far, I will add another similar test suite in #4200 which would duplicate the code a third time) Also, I have moved the test utilities file (test_utils.py) to under the test directory (tests/utils.py), since none of its code is actually used in the main package. Note that I have added __init__.py to each test subpackage and updated the ray.init() call in the test utilities file in order to relative import tests/utils.py. --- .buildkite/test-pipeline.yaml | 24 +++-- tests/async_engine/__init__.py | 0 tests/async_engine/test_openapi_server_ray.py | 51 +---------- tests/basic_correctness/__init__.py | 0 tests/core/block/e2e/__init__.py | 0 tests/core/block/e2e/conftest.py | 3 +- tests/distributed/__init__.py | 0 .../test_basic_distributed_correctness.py | 6 +- tests/distributed/test_comm_ops.py | 5 +- tests/distributed/test_custom_all_reduce.py | 5 +- tests/engine/__init__.py | 0 tests/engine/output_processor/__init__.py | 0 .../output_processor/test_multi_step.py | 3 +- tests/entrypoints/__init__.py | 0 tests/entrypoints/test_openai_server.py | 47 +--------- tests/kernels/__init__.py | 0 tests/kernels/test_activation.py | 3 +- tests/kernels/test_attention.py | 3 +- tests/kernels/test_pos_encoding.py | 3 +- tests/metrics/__init__.py | 0 tests/model_executor/__init__.py | 0 tests/models/__init__.py | 0 tests/models/test_gptq_marlin.py | 3 +- tests/models/test_marlin.py | 3 +- tests/models/test_mistral.py | 2 +- tests/prefix_caching/__init__.py | 0 tests/quantization/__init__.py | 0 tests/samplers/__init__.py | 0 tests/samplers/test_logprobs.py | 3 +- tests/spec_decode/e2e/conftest.py | 3 +- tests/tensorizer_loader/test_tensorizer.py | 3 +- tests/test_sequence.py | 3 +- tests/utils.py | 89 +++++++++++++++++++ vllm/test_utils.py | 40 --------- 34 files changed, 138 insertions(+), 164 deletions(-) create mode 100644 tests/async_engine/__init__.py create mode 100644 tests/basic_correctness/__init__.py create mode 100644 tests/core/block/e2e/__init__.py create mode 100644 tests/distributed/__init__.py create mode 100644 tests/engine/__init__.py create mode 100644 tests/engine/output_processor/__init__.py create mode 100644 tests/entrypoints/__init__.py create mode 100644 tests/kernels/__init__.py create mode 100644 tests/metrics/__init__.py create mode 100644 tests/model_executor/__init__.py create mode 100644 tests/models/__init__.py create mode 100644 tests/prefix_caching/__init__.py create mode 100644 tests/quantization/__init__.py create mode 100644 tests/samplers/__init__.py create mode 100644 tests/utils.py delete mode 100644 vllm/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2eeba904a209d..4feea786f38ba 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -24,28 +24,26 @@ steps: command: pytest -v -s core - label: Distributed Comm Ops Test - command: pytest -v -s test_comm_ops.py - working_dir: "/vllm-workspace/tests/distributed" + command: pytest -v -s distributed/test_comm_ops.py + working_dir: "/vllm-workspace/tests" num_gpus: 2 - label: Distributed Tests - working_dir: "/vllm-workspace/tests/distributed" - - num_gpus: 2 # only support 1 or 2 for now. + working_dir: "/vllm-workspace/tests" + num_gpus: 2 mirror_hardwares: [amd] - commands: - - pytest -v -s test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_pynccl_library.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py - label: Distributed Tests (Multiple Groups) - working_dir: "/vllm-workspace/tests/distributed" + working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - - pytest -v -s test_pynccl.py + - pytest -v -s distributed/test_pynccl.py - label: Engine Test #mirror_hardwares: [amd] diff --git a/tests/async_engine/__init__.py b/tests/async_engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 4b97af88012b9..ace4c53916c71 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -1,61 +1,16 @@ -# imports for guided decoding tests -import os -import subprocess -import sys -import time - import openai # use the official client for correctness check import pytest # using Ray for overall ease of process management, parallel requests, # and debugging. import ray -import requests -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +from ..utils import ServerRunner + # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): ray.init() server_runner = ServerRunner.remote([ diff --git a/tests/basic_correctness/__init__.py b/tests/basic_correctness/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/block/e2e/__init__.py b/tests/core/block/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index 1d99cb5d32219..b0d62c8993d3f 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,9 +1,10 @@ import pytest -from tests.conftest import cleanup from vllm import LLM from vllm.model_executor.utils import set_random_seed +from ....conftest import cleanup + @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 527452630c9f5..d63f015511ada 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -4,10 +4,12 @@ variables. Run: ```sh +cd $VLLM_PATH/tests + TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py ``` """ import os diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index a4423bbfddf46..53654dc40d10d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,8 +11,9 @@ from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index bdca031e39be1..630b25a3b6132 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -10,8 +10,9 @@ graph_capture, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, get_tp_ca_communicator) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] diff --git a/tests/engine/__init__.py b/tests/engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/__init__.py b/tests/engine/output_processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 2bf4bf69da203..4f32a622546f0 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -4,7 +4,6 @@ import pytest from transformers import PreTrainedTokenizer -from tests.core.utils import create_seq_group from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker @@ -14,6 +13,8 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter +from ...core.utils import create_seq_group + @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [1, 12]) diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index c22ac4507658b..ee2f034fd2c46 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1,10 +1,6 @@ # imports for guided decoding tests import json -import os import re -import subprocess -import sys -import time import jsonschema import openai # use the official client for correctness check @@ -12,7 +8,6 @@ # using Ray for overall ease of process management, parallel requests, # and debugging. import ray -import requests import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -20,7 +15,8 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +from ..utils import ServerRunner + # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -78,45 +74,6 @@ pytestmark = pytest.mark.asyncio -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - @pytest.fixture(scope="session") def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 86ecc6412c648..a624c4ca9ee62 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -2,11 +2,12 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, SiluAndMul) +from .allclose_default import get_default_atol, get_default_rtol + DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 28496f187d466..fdf313262ca97 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,13 +3,14 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip +from .allclose_default import get_default_atol, get_default_rtol + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index bf1856972cf33..18c8e351aa778 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -3,10 +3,11 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.rotary_embedding import get_rope +from .allclose_default import get_default_atol, get_default_rtol + IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/__init__.py b/tests/model_executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index b1c2b88bc99af..db55d4488a374 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -13,9 +13,10 @@ import pytest import torch -from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from .utils import check_logprobs_close + os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index fa846d43d0e88..37c1664afec55 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -15,9 +15,10 @@ import pytest import torch -from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from .utils import check_logprobs_close + capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] marlin_not_supported = (capability < diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 33d28da85d9e7..d0a5bfbfcd922 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -4,7 +4,7 @@ """ import pytest -from tests.models.utils import check_logprobs_close +from .utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", diff --git a/tests/prefix_caching/__init__.py b/tests/prefix_caching/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/__init__.py b/tests/samplers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 57d6d2a410ee5..40d054cd472b8 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,9 +1,10 @@ import pytest import torch -from tests.conftest import VllmRunner from vllm import SamplingParams +from ..conftest import VllmRunner + MODELS = ["facebook/opt-125m"] diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index eda7293ea7cee..da8b92711380e 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -9,7 +9,6 @@ from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit) -from tests.conftest import cleanup from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -21,6 +20,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid +from ...conftest import cleanup + class AsyncLLM: """AsyncLLM diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index df1db4e6c4001..ad4748c5ebe96 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -9,12 +9,13 @@ import ray import torch -from tests.entrypoints.test_openai_server import ServerRunner from vllm import SamplingParams from vllm.model_executor.model_loader.tensorizer import ( EncryptionParams, TensorizerConfig, TensorSerializer, is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) +from ..utils import ServerRunner + prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b8ea1f6b77200..3136402518b9f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,9 +1,10 @@ import pytest -from tests.core.utils import create_dummy_prompt from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, SequenceData, SequenceOutput) +from .core.utils import create_dummy_prompt + @pytest.fixture def sample_outputs(): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000..689d8c8c5ba8a --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,89 @@ +import os +import subprocess +import sys +import time + +import ray +import requests + +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.utils import get_open_port + +# Path to root of repository so that utilities can be imported by ray workers +VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) + + +@ray.remote(num_gpus=1) +class ServerRunner: + MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > self.MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +def init_test_distributed_environment( + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, + local_rank: int = -1, +) -> None: + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=pp_size * tp_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=local_rank) + ensure_model_parallel_initialized(tp_size, pp_size) + + +def multi_process_tensor_parallel( + tp_size: int, + pp_size: int, + test_target, +) -> None: + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. + ray.init(runtime_env={"working_dir": VLLM_PATH}) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(tp_size * pp_size): + refs.append( + test_target.remote(tp_size, pp_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() diff --git a/vllm/test_utils.py b/vllm/test_utils.py deleted file mode 100644 index addd8ec1c26c8..0000000000000 --- a/vllm/test_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import ray - -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.utils import get_open_port - - -def init_test_distributed_environment( - tp_size: int, - pp_size: int, - rank: int, - distributed_init_port: str, - local_rank: int = -1, -) -> None: - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - init_distributed_environment( - world_size=pp_size * tp_size, - rank=rank, - distributed_init_method=distributed_init_method, - local_rank=local_rank) - ensure_model_parallel_initialized(tp_size, pp_size) - - -def multi_process_tensor_parallel( - tp_size: int, - pp_size: int, - test_target, -) -> None: - # Using ray helps debugging the error when it failed - # as compared to multiprocessing. - ray.init() - - distributed_init_port = get_open_port() - refs = [] - for rank in range(tp_size * pp_size): - refs.append( - test_target.remote(tp_size, pp_size, rank, distributed_init_port)) - ray.get(refs) - - ray.shutdown() From e7c46b9527c9a50253657fd0078a0b1f23560ce4 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Mon, 13 May 2024 23:50:44 +0900 Subject: [PATCH 033/124] [Scheduler] Warning upon preemption and Swapping (#4647) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- docs/source/models/performance.rst | 19 ++++++++++ tests/basic_correctness/test_preemption.py | 44 +++++++++++++++++++++- tests/conftest.py | 16 ++++++++ tests/core/test_scheduler.py | 1 + vllm/core/scheduler.py | 18 +++++++++ vllm/engine/llm_engine.py | 4 +- vllm/engine/metrics.py | 9 ++++- 7 files changed, 108 insertions(+), 3 deletions(-) diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst index 589fce21056c2..d8750ddc34e8e 100644 --- a/docs/source/models/performance.rst +++ b/docs/source/models/performance.rst @@ -3,6 +3,25 @@ Performance and Tuning ====================== +Preemption +---------- +Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. +The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes +available again. When this occurs, the following warning is printed: + +``` +WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1 +``` + +While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency. +If you frequently encounter preemptions from the vLLM engine, consider the following actions: + +- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space. +- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space. +- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache. + +You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False. + Chunked Prefill --------------- vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index ffb0717b3bfdb..29a4c39cd25a1 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -6,6 +6,7 @@ pytest tests/basic_correctness/test_preemption.py`. """ import pytest +from prometheus_client import REGISTRY from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, @@ -71,6 +72,7 @@ def test_chunked_prefill_recompute( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) def test_preemption( + caplog_vllm, hf_runner, vllm_runner, example_prompts, @@ -87,10 +89,13 @@ def test_preemption( vllm_model = vllm_runner( model, dtype=dtype, + disable_log_stats=False, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) del vllm_model for i in range(len(example_prompts)): @@ -100,6 +105,20 @@ def test_preemption( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption @pytest.mark.parametrize("model", MODELS) @@ -107,6 +126,7 @@ def test_preemption( @pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("beam_width", [4]) def test_swap( + caplog_vllm, hf_runner, vllm_runner, example_prompts, @@ -122,11 +142,18 @@ def test_swap( max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, swap_space=10) + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + disable_log_stats=False, + ) vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) del vllm_model for i in range(len(example_prompts)): @@ -138,6 +165,21 @@ def test_swap( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + assert ("is preempted by PreemptionMode.SWAP mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) diff --git a/tests/conftest.py b/tests/conftest.py index b8117a19c75d9..999ace2c3c699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type): pool_type="ray", extra_config={}) raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 6bcabc4f95fa9..07fc8731e1847 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort(): and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 assert scheduler.get_num_unfinished_seq_groups() == 2 + assert out.preempted == 1 # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fb6e985b2f31c..fbde27f998233 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -129,6 +129,7 @@ class SchedulerOutputs: num_lookahead_slots: int # The number of requests in the running queue running_queue_size: int + preempted: int def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -310,6 +311,7 @@ def __init__( self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT if self.enable_artificial_preemption else 0) + self.num_cumulative_preemption: int = 0 @property def lora_enabled(self) -> bool: @@ -785,6 +787,8 @@ def _schedule_default(self) -> SchedulerOutputs: # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) + preempted = (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -804,6 +808,7 @@ def _schedule_default(self) -> SchedulerOutputs: swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), + preempted=preempted, ) def _schedule_chunked_prefill(self): @@ -891,6 +896,8 @@ def _schedule_chunked_prefill(self): ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -1057,6 +1064,17 @@ def _preempt( preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + if preemption_mode == PreemptionMode.RECOMPUTE: self._preempt_by_recompute(seq_group) elif preemption_mode == PreemptionMode.SWAP: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46fa41030b4a1..e258a3f4afd54 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -737,6 +737,8 @@ def _get_stats( num_generation_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) # Request stats # Latency @@ -830,7 +832,6 @@ def _get_stats( return Stats( now=now, - # System stats # Scheduler State num_running_sys=num_running_sys, @@ -846,6 +847,7 @@ def _get_stats( time_to_first_tokens_iter=time_to_first_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, # Request stats # Latency diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 3c4aac91549a9..ae7ae144bc04f 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -61,6 +61,10 @@ def __init__(self, labelnames: List[str], max_model_len: int): labelnames=labelnames) # Iteration stats + self.counter_num_preemption = Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames) self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -181,6 +185,7 @@ class Stats: num_generation_tokens_iter: int time_to_first_tokens_iter: List[float] time_per_output_tokens_iter: List[float] + num_preemption_iter: int # Request stats (should have _requests suffix) # Latency @@ -244,6 +249,8 @@ def _log_prometheus(self, stats: Stats) -> None: stats.cpu_cache_usage_sys) # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) self._log_counter(self.metrics.counter_generation_tokens, @@ -336,7 +343,7 @@ def log(self, stats: Stats) -> None: "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%", + "CPU KV cache usage: %.1f%%.", prompt_throughput, generation_throughput, stats.num_running_sys, From 0fca3cdcf265cd375bca684d951702b6b7adf65a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 10:47:25 -0700 Subject: [PATCH 034/124] [Misc] Enhance attention selector (#4751) --- tests/worker/test_model_runner.py | 1 - vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 5 +- vllm/attention/backends/flash_attn.py | 13 +++-- vllm/attention/backends/flashinfer.py | 33 +++++++---- vllm/attention/backends/rocm_flash_attn.py | 16 ++--- vllm/attention/backends/torch_sdpa.py | 28 +++++---- vllm/attention/backends/xformers.py | 12 ++-- vllm/attention/layer.py | 19 +++++- vllm/attention/selector.py | 28 +++++++-- vllm/model_executor/model_loader/__init__.py | 19 +++--- vllm/model_executor/model_loader/loader.py | 61 +++++++++++++------- vllm/model_executor/models/arctic.py | 16 ++++- vllm/model_executor/models/baichuan.py | 29 +++++++--- vllm/model_executor/models/bloom.py | 15 +++-- vllm/model_executor/models/chatglm.py | 20 +++++-- vllm/model_executor/models/commandr.py | 14 ++++- vllm/model_executor/models/dbrx.py | 17 ++++-- vllm/model_executor/models/decilm.py | 4 +- vllm/model_executor/models/deepseek.py | 16 ++++- vllm/model_executor/models/falcon.py | 15 +++-- vllm/model_executor/models/gemma.py | 14 +++-- vllm/model_executor/models/gpt2.py | 16 +++-- vllm/model_executor/models/gpt_bigcode.py | 14 +++-- vllm/model_executor/models/gpt_j.py | 20 +++++-- vllm/model_executor/models/gpt_neox.py | 16 +++-- vllm/model_executor/models/internlm2.py | 13 ++++- vllm/model_executor/models/jais.py | 12 +++- vllm/model_executor/models/llama.py | 17 ++++-- vllm/model_executor/models/llava.py | 6 +- vllm/model_executor/models/minicpm.py | 13 ++++- vllm/model_executor/models/mixtral.py | 13 ++++- vllm/model_executor/models/mixtral_quant.py | 31 ++++++---- vllm/model_executor/models/mpt.py | 18 ++++-- vllm/model_executor/models/olmo.py | 14 +++-- vllm/model_executor/models/opt.py | 16 +++-- vllm/model_executor/models/orion.py | 13 ++++- vllm/model_executor/models/phi.py | 16 +++-- vllm/model_executor/models/qwen.py | 15 ++++- vllm/model_executor/models/qwen2.py | 14 +++-- vllm/model_executor/models/qwen2_moe.py | 16 ++++- vllm/model_executor/models/stablelm.py | 14 +++-- vllm/model_executor/models/starcoder2.py | 18 +++++- vllm/model_executor/models/xverse.py | 14 +++-- vllm/worker/cache_engine.py | 14 ++++- vllm/worker/cpu_model_runner.py | 15 +++-- vllm/worker/cpu_worker.py | 10 +++- vllm/worker/embedding_model_runner.py | 1 - vllm/worker/model_runner.py | 15 +++-- 49 files changed, 573 insertions(+), 220 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3e3d2e3f5c53d..c2d1c5769619b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) - assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 7636b34a16fed..088f48def7668 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -5,9 +5,9 @@ from vllm.attention.selector import get_attn_backend __all__ = [ + "Attention", "AttentionBackend", "AttentionMetadata", - "Attention", - "get_attn_backend", "AttentionMetadataPerStage", + "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 64ccb309a0480..98d70fcab1a18 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]): # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The kv cache's data type. - kv_cache_dtype: str def __post_init__(self): if self.num_prefill_tokens > 0: @@ -116,6 +114,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: raise NotImplementedError @@ -127,6 +126,6 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4bad226512b69..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -140,16 +140,18 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -167,7 +169,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata[FlashAttentionMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -196,8 +198,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -264,7 +265,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36e162671f944..92d0fe0487516 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -149,20 +149,33 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) - self.alibi_slopes = alibi_slopes - self.scale = scale self.num_heads = num_heads self.head_size = head_size + self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], - kv_scale: float): + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float = 1.0, + ) -> torch.Tensor: + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -183,7 +196,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) if prefill_meta := attn_metadata.prefill_metadata: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8fc1af1aa1e1c..539585b46c7aa 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,25 +138,27 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. @@ -229,7 +231,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) @@ -323,7 +325,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c29218dfd0cfc..2dd72a00c6e30 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -83,26 +83,32 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -111,7 +117,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -124,6 +130,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -136,8 +143,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -195,7 +201,7 @@ def forward( attn_metadata.block_tables, attn_metadata.seq_lens_tensor, attn_metadata.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a9150dea5875..cb2028553461f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,15 +149,17 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -175,7 +177,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[XFormersMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -188,7 +190,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -203,8 +204,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -262,7 +262,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ee7be26c0876c..8a872dba8c877 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig class Attention(nn.Module): @@ -29,10 +30,24 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) - impl_cls = self.backend.get_impl_cls() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + if num_kv_heads is None: + num_kv_heads = num_heads + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) + impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f4446bac6b8d2..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,6 @@ import enum from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch @@ -21,8 +21,18 @@ class _Backend(enum.Enum): @lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> Type[AttentionBackend]: + backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend. ") + logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def _which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): return _Backend.TORCH_SDPA diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 6f90e49994fb2..e3e32d61ab04d 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,26 +2,29 @@ from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model( - *, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: loader = get_model_loader(load_config) return loader.load_model(model_config=model_config, device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, parallel_config=parallel_config, - scheduler_config=scheduler_config) + scheduler_config=scheduler_config, + cache_config=cache_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bafa2de62e5df..fc9c8aa0af44b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,9 @@ import torch from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -77,15 +77,16 @@ def _get_model_initialization_kwargs( return extra_kwargs -def _initialize_model( - model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, + cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -103,7 +104,8 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" ... @@ -216,11 +218,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -253,11 +257,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -286,9 +292,12 @@ def _get_weights_iterator( return tensorizer_weights_iterator(tensorizer_args) def _load_model_unserialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load an unserialized model with tensorizer. @@ -299,15 +308,19 @@ def _load_model_unserialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load a serialized model with tensorizer. @@ -321,6 +334,7 @@ def _load_model_serialized( extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -335,16 +349,19 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) if is_vllm_serialized_tensorizer(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) return self._load_model_unserialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 796cef7c4a735..cb99939cbb17a 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,6 +5,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -215,6 +216,7 @@ def __init__( self, config: ArcticConfig, layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -265,7 +267,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -288,6 +291,7 @@ def __init__( self, config: ArcticConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -297,6 +301,7 @@ def __init__( self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, layer_idx, + cache_config, quant_config=quant_config) self.block_sparse_moe = ArcticMoE( config, @@ -356,6 +361,7 @@ class ArcticModel(nn.Module): def __init__( self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -366,7 +372,10 @@ def __init__( config.hidden_size, org_num_embeddings=self.vocab_size) self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + ArcticDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self._attn_implementation = config._attn_implementation @@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module): def __init__(self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, **kwargs) -> None: super().__init__() self.config = config - self.model = ArcticModel(config, quant_config) + self.model = ArcticModel(config, cache_config, quant_config) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 186cee2584369..58b3405d319d1 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -111,6 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -162,7 +163,10 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size @@ -197,6 +202,7 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = BaiChuanMLP( @@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -255,7 +262,8 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, quant_config) + BaiChuanDecoderLayer(config, position_embedding, cache_config, + quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,13 +312,15 @@ def __init__( self, config, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = BaiChuanModel(config, position_embedding, quant_config) + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", quant_config, lora_config) + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1d7e5d2517c72..fe2de87b20dc9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -71,6 +72,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,7 +110,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + cache_config=cache_config) def forward( self, @@ -158,6 +161,7 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -165,7 +169,8 @@ def __init__( self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, quant_config) + self.self_attention = BloomAttention(config, cache_config, + quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -214,6 +219,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -229,7 +235,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, quant_config) + BloomBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = BloomModel(config, quant_config) + self.transformer = BloomModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e116af2ed080d..29c76682109c6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -34,6 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -90,6 +91,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -167,6 +169,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -181,7 +184,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, quant_config) + self.self_attention = GLMAttention(config, cache_config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -237,6 +240,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -246,8 +250,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList( - [GLMBlock(config, quant_config) for i in range(self.num_layers)]) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -292,7 +299,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, quant_config) + self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config - self.transformer = ChatGLMModel(config, quant_config) + self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 17c2f1223d96b..7354d11f98b15 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -124,6 +125,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -180,6 +182,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, @@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, quant_config=quant_config) + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -258,6 +264,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +273,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, quant_config=quant_config) + CohereDecoderLayer(config, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -306,7 +314,7 @@ def __init__( self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, quant_config) + self.model = CohereModel(config, cache_config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index a4a0ae50c645e..083ddf0159f71 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -166,6 +167,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -221,6 +223,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -279,10 +282,12 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) self.ffn = DbrxExperts(config, quant_config) def forward( @@ -308,6 +313,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -315,8 +321,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + DbrxBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, quant_config) + self.transformer = DbrxModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index be9a6b6813f8f..e293ee491908d 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -28,7 +28,7 @@ import torch from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, + cache_config=cache_config, quant_config=quant_config, lora_config=lora_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index e5f7ba086a35d..62e04f9649915 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -178,6 +179,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -229,7 +231,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -252,6 +255,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -267,6 +271,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.n_routed_experts is not None @@ -321,6 +326,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -332,7 +338,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) + DeepseekDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, quant_config) + self.model = DeepseekModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 08dd69923dc6d..ab9e1994be426 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -77,6 +78,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -168,7 +170,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, quant_config) + self.self_attention = FalconAttention(config, cache_config, + quant_config) self.mlp = FalconMLP(config, quant_config) self.config = config @@ -311,6 +316,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -327,7 +333,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, quant_config) + FalconDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = FalconModel(config, quant_config) + self.transformer = FalconModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bb73ff4d206da..d1502b718a773 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,7 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -107,6 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -155,7 +156,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -188,6 +191,7 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = GemmaMLP( @@ -236,6 +240,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -246,7 +251,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, quant_config) + GemmaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -316,7 +322,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = GemmaModel(config, quant_config) + self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 75eaebf0dbd15..0deaa58ed9eb5 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -70,7 +72,10 @@ def __init__( bias=True, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config) def forward( self, @@ -122,6 +127,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -130,7 +136,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, quant_config) + self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config) @@ -163,6 +169,7 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -174,7 +181,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, quant_config) + GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, quant_config) + self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index d057fd928fdb5..c20fb3230c394 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -85,7 +87,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -151,7 +155,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, quant_config) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -195,7 +200,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, quant_config) + GPTBigCodeBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 8d7fe8a5beef7..5f4d8ec3d3a7a 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -83,7 +85,10 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -135,13 +140,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, quant_config) + self.attn = GPTJAttention(config, cache_config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -169,6 +175,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,8 +185,10 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList( - [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([ + GPTJBlock(config, cache_config, quant_config) + for _ in range(config.n_layer) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, quant_config) + self.transformer = GPTJModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index bab563b9c5a39..dcb52ff666c95 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -84,7 +86,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -142,7 +148,7 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, quant_config) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -192,7 +199,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, quant_config) + GPTNeoXLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(config, quant_config) + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 5811cae83bf8b..65f7ddb8b082c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -64,6 +65,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -114,7 +116,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -151,6 +155,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.feed_forward = InternLM2MLP( @@ -196,6 +201,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -207,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, quant_config) + InternLMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, quant_config) + self.model = InternLM2Model(config, cache_config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bd6a180ec8dfc..df30fd1ba0a37 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -69,6 +70,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,6 +110,7 @@ def __init__( self.head_dim, scale=self.scale, alibi_slopes=alibi_slopes, + cache_config=cache_config, ) def forward( @@ -170,6 +173,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,7 +182,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, quant_config) + self.attn = JAISAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -211,6 +215,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -228,7 +233,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, quant_config) + JAISBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = JAISModel(config, quant_config) + self.transformer = JAISModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 127e4612b2e40..ebdc64e0e220e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -94,6 +94,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -153,7 +154,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -204,6 +207,7 @@ def __init__( quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -251,6 +255,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -267,7 +272,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config, quant_config, lora_config=lora_config) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dcde4dfa0795e..3b99b337a2765 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -7,7 +7,7 @@ from transformers import CLIPVisionModel, LlavaConfig from vllm.attention import AttentionMetadata -from vllm.config import VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -85,7 +86,8 @@ def __init__(self, projector_hidden_act=config.projector_hidden_act) self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, quant_config) + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c90bcfbfc4707..0b85cf1c94795 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -28,7 +28,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -181,6 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +235,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -275,6 +278,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) @@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -346,7 +351,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, quant_config) + MiniCPMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -421,6 +427,7 @@ def __init__( self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, + cache_config, quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index efa4de7516212..113abbaa6036d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -252,6 +252,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -313,6 +314,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -348,6 +351,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -394,6 +398,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -410,7 +415,9 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.model = MixtralModel(config, + cache_config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 38c62afced28a..ee2626b1c1aa2 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,7 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -157,14 +158,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -215,6 +219,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -250,6 +256,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) @@ -292,6 +299,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -303,7 +311,9 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config) + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 6fa5c5bd3014a..716ac51cde94d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -43,6 +44,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -107,7 +109,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -166,12 +169,13 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, quant_config) + self.attn = MPTAttention(config, cache_config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -201,6 +205,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -211,8 +216,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + MPTBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -253,7 +261,7 @@ def __init__( assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(config, quant_config) + self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index f212ea2166e1d..69f23bbfb5d0a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,6 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -93,7 +95,8 @@ def __init__( self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) # Attention output projection. self.o_proj = RowParallelLinear( @@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, quant_config) + self.self_attn = OlmoAttention(config, cache_config, quant_config) # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -217,6 +221,7 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -224,7 +229,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, quant_config) + OlmoDecoderLayer(config, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = OlmoModel(config, quant_config) + self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 336f765ababaa..d241756e50f4a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -61,6 +62,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -88,7 +90,8 @@ def __init__( ) self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) def forward( self, @@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -117,6 +121,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, + cache_config=cache_config, quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -181,6 +186,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -226,7 +232,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, quant_config) + OPTDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +265,11 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, quant_config) + self.decoder = OPTDecoder(config, cache_config, quant_config) def forward( self, @@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, quant_config) + self.model = OPTModel(config, cache_config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 9ab5dfb97c19a..59cd42e31b374 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -68,6 +69,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -118,7 +120,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -155,6 +159,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = OrionMLP( @@ -202,6 +207,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -213,7 +219,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, quant_config) + OrionDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = OrionModel(config, quant_config) + self.model = OrionModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 4a45879201af3..ed25a232f4208 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -63,6 +64,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads @@ -105,7 +107,10 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -155,11 +160,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, quant_config) + self.self_attn = PhiAttention(config, cache_config, quant_config) self.mlp = PhiMLP(config, quant_config) def forward( @@ -186,6 +192,7 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -193,7 +200,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, quant_config) + PhiLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.quant_config = quant_config - self.model = PhiModel(config, quant_config) + self.model = PhiModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e5e0028888c88..d158846a3a1f5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -68,6 +69,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -101,7 +103,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -123,6 +128,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -135,6 +141,7 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, + cache_config=cache_config, quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -175,6 +182,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -186,7 +194,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, quant_config) + QWenBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = QWenModel(config, quant_config) + self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 62bc7fe22c367..31ba6441f9f7a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -29,7 +29,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -87,6 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -137,7 +138,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + sliding_window=self.sliding_window, + cache_config=cache_config) def forward( self, @@ -160,6 +162,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, + cache_config=cache_config, quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( @@ -222,6 +226,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +239,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, quant_config) + Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -294,7 +300,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8da89a2b7ba6c..2a3b0173adf8b 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -187,6 +188,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -238,7 +240,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -261,6 +264,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -276,6 +280,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.num_experts is not None @@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -339,7 +345,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) + Qwen2MoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(config, quant_config) + self.model = Qwen2MoeModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3d4f4f700f867..8b4a5507feade 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -72,6 +73,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config @@ -124,7 +126,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_key_value_heads) + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config) def forward( self, @@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.self_attn = StablelmAttention(config) + self.self_attn = StablelmAttention(config, cache_config, quant_config) self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) @@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( @@ -195,7 +200,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, quant_config) + StablelmDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(config, quant_config) + self.model = StableLMEpochModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 33998e2aad5c5..3c19d63276a77 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -101,6 +103,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -201,7 +208,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, quant_config=quant_config) + Starcoder2DecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, quant_config=quant_config) + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 0fb2662b2f715..6ef230a8ebbca 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -27,7 +27,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -89,6 +89,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -133,7 +134,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = XverseMLP( hidden_size=self.hidden_size, @@ -221,6 +225,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -237,7 +242,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, quant_config) + XverseDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config) + self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1fb63a3e47921..07d51dca226bd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,7 +64,7 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 6c8b1685dadcf..0a0b0d70cfe21 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -53,7 +53,15 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -66,7 +74,8 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, lora_config=self.lora_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) def _prepare_prompt( self, @@ -158,7 +167,6 @@ def _prepare_prompt( decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) @@ -242,7 +250,6 @@ def _prepare_decode( prefill_metadata=None, decode_metadata=None, block_tables=block_tables, - kv_cache_dtype=self.kv_cache_dtype, ) return ( input_tokens, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5e4ae564cb57e..3ee394f9912e9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,15 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 2d3f160c60dc1..d04bebbdc31b6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -235,7 +235,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, pooling_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f46b475bdc2db..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -141,10 +141,18 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization - self.model: torch.nn.Module # Set after load_model + self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. @@ -160,6 +168,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, + cache_config=self.cache_config, ) self.model_memory_usage = m.consumed_memory @@ -753,7 +762,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, @@ -965,7 +973,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: From 8bc68e198c4c90ddc2e54fa76eb81c2c714bb1cd Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Mon, 13 May 2024 17:57:07 -0400 Subject: [PATCH 035/124] [Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update `tensorizer` to version 2.9.0 (#4208) --- .buildkite/test-pipeline.yaml | 4 +- examples/tensorize_vllm_model.py | 200 ++++++-------- requirements-dev.txt | 2 +- setup.py | 2 +- .../tensorize_vllm_model_for_testing.py | 245 ------------------ tests/tensorizer_loader/test_tensorizer.py | 189 +++++--------- vllm/engine/arg_utils.py | 4 +- vllm/envs.py | 2 +- vllm/model_executor/model_loader/loader.py | 28 +- .../model_executor/model_loader/tensorizer.py | 106 ++++++-- 10 files changed, 259 insertions(+), 523 deletions(-) delete mode 100644 tests/tensorizer_loader/tensorize_vllm_model_for_testing.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4feea786f38ba..3c3da41c3abf3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -60,11 +60,13 @@ steps: mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - - pip install awscli + # install tensorizer for tensorize_vllm_model.py + - pip install awscli tensorizer - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py + - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index e2456168de9d5..8b74ae1d75a1d 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -1,23 +1,20 @@ import argparse import dataclasses +import json import os -import time import uuid from functools import partial -from typing import Type -import torch -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig +from tensorizer import stream_io -from vllm.distributed import initialize_model_parallel +from vllm import LLM +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, + TensorizerConfig, + serialize_vllm_model) # yapf conflicts with isort for this docstring # yapf: disable @@ -27,25 +24,25 @@ to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, or locally. Tensor encryption and decryption is also supported, although libsodium must be installed to use it. Install vllm with tensorizer support -using `pip install vllm[tensorizer]`. +using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit +https://github.com/coreweave/tensorizer To serialize a model, install vLLM from source, then run something like this from the root level of this repository: python -m examples.tensorize_vllm_model \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ + --model facebook/opt-125m \ serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm + --serialized-directory s3://my-bucket \ + --suffix v1 Which downloads the model from HuggingFace, loads it into vLLM, serializes it, and saves it to your S3 bucket. A local directory can also be used. This assumes your S3 credentials are specified as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. -To provide S3 credentials directly, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this -script. +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and +`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide +`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint` +as CLI args to this script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. @@ -57,7 +54,7 @@ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. @@ -71,26 +68,30 @@ `python -m examples.tensorize_vllm_model deserialize --help`. -Once a model is serialized, it can be used to load the model when running the -OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing -the `--tensorizer-uri` CLI argument that is functionally the same as the -`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to -signify that the model to be deserialized is a vLLM model, rather than a -HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer -in the same inference server, albeit without the speed optimizations. To -deserialize an encrypted file, the `--encryption-keyfile` argument can be used -to provide the path to the keyfile used to encrypt the model weights. For -information on all the arguments that can be used to configure tensorizer's -deserialization, check out the tensorizer options argument group in the -`vllm/entrypoints/openai/api_server.py` script with `--help`. - -Tensorizer can also be invoked with the `LLM` class directly to load models: +Once a model is serialized, tensorizer can be invoked with the `LLM` class +directly to load models: llm = LLM(model="facebook/opt-125m", load_format="tensorizer", - tensorizer_uri=path_to_opt_tensors, - num_readers=3, - vllm_tensorized=True) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri = path_to_tensors, + num_readers=3, + ) + ) + +A serialized model can be used during model loading for the vLLM OpenAI +inference server. `model_loader_extra_config` is exposed as the CLI arg +`--model-loader-extra-config`, and accepts a JSON string literal of the +TensorizerConfig arguments desired. + +In order to see all of the available arguments usable to configure +loading with tensorizer that are given to `TensorizerConfig`, run: + +`python -m examples.tensorize_vllm_model deserialize --help` + +under the `tensorizer options` section. These can also be used for +deserialization in this example script, although `--tensorizer-uri` and +`--path-to-tensors` are functionally the same in this case. """ @@ -158,95 +159,35 @@ def parse_args(): help=("Path to a binary key to use to decrypt the model weights," " if the model was serialized with encryption")) - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) + TensorizerArgs.add_cli_args(deserialize_parser) - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() + return parser.parse_args() - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + llm = LLM(model=args.model, + load_format="tensorizer", + model_loader_extra_config=tensorizer_config ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") + return llm - return model args = parse_args() -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) +s3_access_key_id = (getattr(args, 's3_access_key_id', None) + or os.environ.get("S3_ACCESS_KEY_ID", None)) +s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) + or os.environ.get("S3_SECRET_ACCESS_KEY", None)) +s3_endpoint = (getattr(args, 's3_endpoint', None) + or os.environ.get("S3_ENDPOINT_URL", None)) -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) +credentials = { + "s3_access_key_id": s3_access_key_id, + "s3_secret_access_key": s3_secret_access_key, + "s3_endpoint": s3_endpoint +} _read_stream, _write_stream = (partial( stream_io.open_stream, @@ -263,20 +204,41 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None + +if args.model_loader_extra_config: + config = json.loads(args.model_loader_extra_config) + tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args() + tensorizer_args.tensorizer_uri = args.path_to_tensors +else: + tensorizer_args = None + if args.command == "serialize": + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + input_dir = args.serialized_directory.rstrip('/') suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" model_path = f"{base_path}/model.tensors" - serialize() + tensorizer_config = TensorizerConfig( + tensorizer_uri=model_path, + **credentials) + serialize_vllm_model(engine, tensorizer_config, keyfile) elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors + if not tensorizer_args: + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + encryption_keyfile = keyfile, + **credentials + ) deserialize() else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-dev.txt b/requirements-dev.txt index 796c9e37d0230..4f6c27d95fe6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ types-setuptools # testing pytest -tensorizer==2.9.0 +tensorizer>=2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 0dc8818b44a9e..a66af2c5d556f 100644 --- a/setup.py +++ b/setup.py @@ -426,7 +426,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "tensorizer": ["tensorizer==2.9.0"], + "tensorizer": ["tensorizer>=2.9.0"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py deleted file mode 100644 index 0e113ab647e67..0000000000000 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ /dev/null @@ -1,245 +0,0 @@ -import argparse -import dataclasses -import os -import time -import uuid -from functools import partial -from typing import Type - -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig - -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry - -# yapf conflicts with isort for this docstring -# yapf: disable -""" -tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. - -To serialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm - -Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. - -You can also encrypt the model weights with a randomly-generated key by -providing a `--keyfile` argument. - -To deserialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors - -Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - - -You can also provide a `--keyfile` argument to decrypt the model weights if -they were serialized with encryption. - -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. -""" - - -def parse_args(): - parser = argparse.ArgumentParser( - description="An example script that can be used to serialize and " - "deserialize vLLM models. These models " - "can be loaded using tensorizer directly to the GPU " - "extremely quickly. Tensor encryption and decryption is " - "also supported, although libsodium must be installed to " - "use it.") - parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser)) - subparsers = parser.add_subparsers(dest='command') - - serialize_parser = subparsers.add_parser( - 'serialize', help="Serialize a model to `--serialized-directory`") - - serialize_parser.add_argument( - "--suffix", - type=str, - required=False, - help=( - "The suffix to append to the serialized model directory, which is " - "used to construct the location of the serialized model tensors, " - "e.g. if `--serialized-directory` is `s3://my-bucket/` and " - "`--suffix` is `v1`, the serialized model tensors will be " - "saved to " - "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " - "If none is provided, a random UUID will be used.")) - serialize_parser.add_argument( - "--serialized-directory", - type=str, - required=True) - - serialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Encrypt the model weights with a randomly-generated binary key," - " and save the key at this path")) - - deserialize_parser = subparsers.add_parser( - 'deserialize', - help=("Deserialize a model from `--path-to-tensors`" - " to verify it can be loaded and used.")) - - deserialize_parser.add_argument( - "--path-to-tensors", - type=str, - required=True, - help="The local path or S3 URI to the model tensors to deserialize. ") - - deserialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Path to a binary key to use to decrypt the model weights," - " if the model was serialized with encryption")) - - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) - - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() - - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) - - -def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" - ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") - - return model - - -args = parse_args() - -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) - -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) - -_read_stream, _write_stream = (partial( - stream_io.open_stream, - mode=mode, - s3_access_key_id=s3_access_key_id, - s3_secret_access_key=s3_secret_access_key, - s3_endpoint=s3_endpoint, -) for mode in ("rb", "wb+")) - -model_ref = args.model - -model_name = model_ref.split("/")[1] - -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "8080" - -init_distributed_environment(world_size=1, rank=0, local_rank=0) -initialize_model_parallel() - -keyfile = args.keyfile if args.keyfile else None - -if args.command == "serialize": - input_dir = args.serialized_directory.rstrip('/') - suffix = args.suffix if args.suffix else uuid.uuid4().hex - base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" - model_path = f"{base_path}/model.tensors" - serialize() -elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors - deserialize() -else: - raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index ad4748c5ebe96..1579d53a7fe29 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -10,12 +10,19 @@ import torch from vllm import SamplingParams -from vllm.model_executor.model_loader.tensorizer import ( - EncryptionParams, TensorizerConfig, TensorSerializer, - is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) +# yapf: disable +from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + serialize_vllm_model) from ..utils import ServerRunner +# yapf conflicts with isort for this docstring + + prompts = [ "Hello, my name is", "The president of the United States is", @@ -40,7 +47,7 @@ def is_curl_installed(): @pytest.fixture(autouse=True) def tensorizer_config(): - config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + config = TensorizerConfig(tensorizer_uri="vllm") return config @@ -59,47 +66,6 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): assert result == mock_agent_instance.deserialize.return_value -def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = True - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is True - - -def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = False - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is False - - -def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): - vllm_model = vllm_runner(model_ref) - model_path = tmp_path / (model_ref + ".tensors") - outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model - gc.collect() - torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner( - model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path, - num_readers=1, - vllm_tensorized=True), - ) - deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - - # Assumes SamplingParams being seeded ensures the outputs are deterministic - assert outputs == deserialized_outputs - - @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" @@ -110,7 +76,6 @@ def test_can_deserialize_s3(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", )) @@ -126,29 +91,26 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( model_path = tmp_path / (model_ref + ".tensors") key_path = tmp_path / (model_ref + ".key") outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - encryption_params = EncryptionParams.random() - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - with open_stream(key_path, "wb+") as stream: - stream.write(encryption_params.key) - del vllm_model, model + config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) + serialize_vllm_model(vllm_model.model.llm_engine, + config_for_serializing, + encryption_key_path=key_path) + + del vllm_model gc.collect() torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=model_path, - encryption_keyfile=key_path, - num_readers=1, - vllm_tensorized=True)) + + config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, + encryption_keyfile=key_path) + + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - # Assumes SamplingParams being seeded ensures the outputs are deterministic assert outputs == deserialized_outputs @@ -169,7 +131,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=False)) + )) deserialized_outputs = loaded_hf_model.generate_greedy( prompts, max_tokens=max_tokens) @@ -190,12 +152,11 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): # Serialize model before deserializing and binding LoRA adapters vllm_model = vllm_runner(model_ref, ) model_path = tmp_path / (model_ref + ".tensors") - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model + + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) + + del vllm_model gc.collect() torch.cuda.empty_cache() loaded_vllm_model = vllm_runner( @@ -204,7 +165,6 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=True, ), enable_lora=True, max_loras=1, @@ -220,58 +180,28 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - model_loader_extra_config=TensorizerConfig( - tensorizer_uri="test", vllm_tensorized=False)) - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_tensorize_vllm_model(tmp_path): - # Test serialize command - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command - - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") - - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" - - # Test deserialize command - deserialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "deserialize", "--path-to-tensors", - path_to_tensors - ] - result = subprocess.run(deserialize_args, capture_output=True, text=True) - assert result.returncode == 0, (f"Deserialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + vllm_runner( + model_ref, + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_openai_apiserver_with_tensorizer(tmp_path): +def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Serialize model - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" model_loader_extra_config = { - "tensorizer_uri": path_to_tensors, - "vllm_tensorized": True + "tensorizer_uri": str(model_path), } + del vllm_model + gc.collect() + torch.cuda.empty_cache() + ## Start OpenAI API server openai_args = [ "--model", model_ref, "--dtype", "float16", "--load-format", @@ -304,10 +234,10 @@ def test_openai_apiserver_with_tensorizer(tmp_path): def test_raise_value_error_on_invalid_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - load_format="safetensors", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri="test", vllm_tensorized=False)) + vllm_runner( + model_ref, + load_format="safetensors", + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) def test_tensorizer_with_tp(vllm_runner): @@ -321,8 +251,29 @@ def test_tensorizer_with_tp(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", ), tensor_parallel_size=2, ) + + +def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path)) + + vllm_model = vllm_runner(model_ref) + outputs = vllm_model.generate(prompts, sampling_params) + serialize_vllm_model(vllm_model.model.llm_engine, config) + + assert is_vllm_tensorized(config) + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=config) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + assert outputs == deserialized_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 163723b4be364..fd5338c46c340 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -167,8 +167,8 @@ def add_cli_args( '* "dummy" will initialize the weights with random values, ' 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave which assumes tensorizer_uri is set to the location of ' - 'the serialized weights.') + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n') parser.add_argument( '--dtype', type=str, diff --git a/vllm/envs.py b/vllm/envs.py index 91cc8f3be775c..68d8a074d0914 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": - lambda: os.environ.get("S3_ACCESS_KEY", None), + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), "S3_ENDPOINT_URL": diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fc9c8aa0af44b..b14824a359b6d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) @@ -291,7 +291,7 @@ def _get_weights_iterator( tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) - def _load_model_unserialized( + def _load_model_serialized_cpu( self, model_config: ModelConfig, device_config: DeviceConfig, @@ -299,11 +299,12 @@ def _load_model_unserialized( vision_language_config: Optional[VisionLanguageConfig], cache_config: CacheConfig, ) -> nn.Module: - """Load an unserialized model with tensorizer. + """Load a serialized model with tensorizer to the CPU. - Unserialized here means "not serialized with tensorizer". This - should still be faster than default HuggingFace loading, but will - be slower than loading a tensorizer-serialized model. + This is only necessary when the model isn't vLLM-tensorized (see + examples/tensorize_vllm_model.py) This should still be faster than + default HuggingFace loading, but will be slower than loading a + vLLM-tensorized model. """ with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): @@ -324,8 +325,9 @@ def _load_model_serialized( ) -> nn.Module: """Load a serialized model with tensorizer. - See the examples/tensorize_vllm_model.py example " - script for serializing vLLM models.""" + Expects a vLLM-tensorized model. See the + examples/tensorize_vllm_model.py example script + for serializing vLLM models.""" with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] @@ -353,15 +355,15 @@ def load_model(self, *, model_config: ModelConfig, cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) - if is_vllm_serialized_tensorizer(self.tensorizer_config): + if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, vision_language_config, cache_config) - return self._load_model_unserialized(model_config, device_config, - lora_config, - vision_language_config, - cache_config) + return self._load_model_serialized_cpu(model_config, device_config, + lora_config, + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 219a2a392e129..2cf4ce5f88521 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -5,6 +5,7 @@ import time import typing from dataclasses import dataclass +from functools import partial from typing import Generator, Optional, Tuple, Type, Union import torch @@ -13,6 +14,7 @@ import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig +from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -27,6 +29,11 @@ from tensorizer.stream_io import open_stream from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) except ImportError as e: tensorizer_error_msg = str(e) @@ -43,7 +50,7 @@ class TensorizerConfig: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ -93,17 +100,11 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig, return tensorizer.deserialize() -def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: - if tensorizer_config is None: - return False - return tensorizer_config.vllm_tensorized - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ -121,7 +122,9 @@ class TensorizerArgs: vLLM model. This is used to determine the behavior of the TensorDeserializer when loading tensors from a serialized model. It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. verify_hash: If True, the hashes of each tensor will be verified against the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. @@ -158,6 +161,7 @@ def __post_init__(self): "encryption": self.encryption_keyfile, "num_readers": self.num_readers } + if self.encryption_keyfile: with open_stream( self.encryption_keyfile, @@ -177,7 +181,14 @@ def add_cli_args( 'tensorizer options', description=('Options for configuring the behavior of the' ' tensorizer deserializer when ' - '--load-format=tensorizer')) + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) group.add_argument( "--tensorizer-uri", @@ -222,13 +233,6 @@ def add_cli_args( help="The endpoint for the S3 bucket. Can also be set via the " "S3_ENDPOINT_URL environment variable.", ) - group.add_argument( - "--vllm-tensorized", - action="store_true", - help="If enabled, indicates that the serialized model is a vLLM " - "model. This is used to determine the behavior of the " - "TensorDeserializer when loading tensors from a " - "serialized model.") return parser @@ -322,10 +326,9 @@ def deserialize(self): """ before_mem = get_mem_usage() start = time.perf_counter() - with open_stream( - self.tensorizer_args.tensorizer_uri, - mode="rb", - **self.tensorizer_args.stream_params, + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params ) as stream, TensorDeserializer( stream, dtype=self.tensorizer_config.dtype, @@ -345,6 +348,7 @@ def deserialize(self): self._check_tensors_on_meta_device() self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker return self.model.eval() @@ -366,3 +370,63 @@ def tensorizer_weights_iterator( for name, param in state.items(): yield name, param del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. + """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + if (".vllm_tensorized_marker" in deserializer): + return True + return False + + +def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module: + model = (engine.model_executor.driver_worker.model_runner.model) + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + return model + + +def serialize_vllm_model(engine: "LLMEngine", + tensorizer_config : TensorizerConfig, + encryption_key_path: Optional[str] = None) \ + -> nn.Module: + + model = get_pretensorized_vllm_model(engine) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + encryption_params = None + if encryption_key_path is not None: + encryption_params = EncryptionParams.random() + with _write_stream(encryption_key_path, + **tensorizer_args.stream_params) as stream: + stream.write(encryption_params.key) + + with _write_stream(tensorizer_args.tensorizer_uri, + **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", + str(tensorizer_args.tensorizer_uri)) + return model From ce532ff45c8008c7157eb448860c13bcdd44823f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 13 May 2024 15:00:13 -0700 Subject: [PATCH 036/124] [Speculative decoding] Improve n-gram efficiency (#4724) --- tests/spec_decode/test_ngram_worker.py | 17 ++--- vllm/config.py | 13 ++-- vllm/spec_decode/ngram_worker.py | 97 ++++++++++++++------------ vllm/spec_decode/top1_proposer.py | 63 +++++++++++++++++ 4 files changed, 132 insertions(+), 58 deletions(-) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index de305c4030aa9..88b40d1eb4674 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find no candidate @@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find no candidate @@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) assert proposals.proposal_lens.shape == torch.Size([5]) + # the first sequence has no match so proposal_len should be overwritten to 0 assert proposals.proposal_lens.tolist( - ) == [proposal_len for _ in range(4)] + [0] + ) == [0] + [proposal_len for _ in range(3)] + [0] for i in range(proposal_len): - assert proposals.proposal_token_ids[0][i] == 0 + assert proposals.proposal_token_ids[0][i] == -1 assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] @@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find candidate 12,13,14,15,16 diff --git a/vllm/config.py b/vllm/config.py index fab9cfbf41a2d..435f47dc9459a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -784,12 +784,15 @@ def maybe_create_spec_config( draft_quantization = None if speculative_model == "[ngram]": - assert (ngram_prompt_lookup_max is not None - and ngram_prompt_lookup_max > 0) if ngram_prompt_lookup_min is None: - ngram_prompt_lookup_min = 0 - else: - assert ngram_prompt_lookup_max > ngram_prompt_lookup_min + ngram_prompt_lookup_min = 1 + if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: + raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") + if ngram_prompt_lookup_min < 1: + raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") + if ngram_prompt_lookup_min > ngram_prompt_lookup_max: + raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " + f"larger than {ngram_prompt_lookup_max=}") # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 6cd50fcc1a041..9628f7af5315a 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -77,9 +77,11 @@ def sampler_output( """ self._raise_if_unsupported(execute_model_req) - arr = [] has_spec_out = False - for seq_group_metadata in execute_model_req.seq_group_metadata_list: + token_id_list = [] + token_prob_list = [] + for idx, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): seq_data = next(iter(seq_group_metadata.seq_data.values())) input_ids = torch.as_tensor(seq_data.get_token_ids(), @@ -89,59 +91,64 @@ def sampler_output( for ngram_size in range( min(self.ngram_prompt_lookup_max, input_length - 1), - self.ngram_prompt_lookup_min, + self.ngram_prompt_lookup_min - 1, -1, ): - ngram_tensor = input_ids[-1 * ngram_size:] - windows = input_ids.unfold(dimension=0, - size=ngram_size, - step=1) - matches = (windows == ngram_tensor).all(dim=1) - match_indices = matches.nonzero(as_tuple=True)[0] - if match_indices.size()[0] > 1: + ngram_tensor = input_ids[-ngram_size:] + proposal_start_idx = None + if ngram_size == 1: + # Do not match itself and do not use unfold and all + matches = (input_ids[:-1] == ngram_tensor) + else: + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + # Do not match itself + matches = (windows[:-1] == ngram_tensor).all(dim=-1) + + # first_match includes "values" (bool), indicating whether + # the match is found, and "indices", indicating the index + # of the first match. + # Note that "first_match.values.item()" triggers GPU-CPU + # sync so it is a bit inefficient, but we have not found + # a better way to do this. + first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=self.device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, index=spec_indices) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) has_spec_out = True - res = seq_data.get_token_ids() - res = res[match_indices[0] + ngram_size:match_indices[0] + - ngram_size + sample_len] - res_len = len(res) - # pad 0 towards output as sample_len tokens required - res += [0] * (sample_len - res_len) - break else: - # if no candidate found, fill with 0 - res = [0] * sample_len - - arr.append(res) + token_id_list.append(None) + token_prob_list.append(None) if not has_spec_out: return None, False - outputs = [] - token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device) - indices = token_ids.unsqueeze(2) + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) - token_probs = torch.zeros( - (len(execute_model_req.seq_group_metadata_list), sample_len, - self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - token_probs.scatter_(2, indices, 1) - token_logprobs = torch.zeros( - (len(execute_model_req.seq_group_metadata_list), sample_len, - self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - for i in range(len(execute_model_req.seq_group_metadata_list)): - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_probs[i], - logprobs=token_logprobs[i], - sampled_token_ids=token_ids[i], - )) return outputs, False def get_spec_proposals( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index ee9462b68dae8..6c7e22207f6b2 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -73,6 +73,14 @@ def get_proposals( execute_model_req=nonzero_execute_model_req, sample_len=proposal_len, ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) else: # If no sequences can be speculated, set sampler output to None. maybe_sampler_output = None @@ -140,6 +148,61 @@ def _split_by_proposal_len( nonzero_proposal_len_indices, ) + def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices, transposed): + """Remove sequences from nonzero_proposal_len_indices and reset + their proposal_len to 0 the draft worker does not provide a proposal + (maybe_sampler_output=None). This can avoid scoring overheads. + """ + + # If maybe_sampler_output is None, then the draft worker did not + # provide a proposal for any sequence and thus no action needed. + # Also we do not support transposed maybe_sampler_output for now + # because it seems not straightforward for draft workers outputting + # transposed sampler outputs to handle the case of no proposal. + if maybe_sampler_output is None or transposed: + return (proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices) + + new_proposal_lens: List[int] = [] + new_nonzero_proposal_len_indices: List[int] = [] + new_maybe_sampler_output: List[SamplerOutput] = [] + nonzero_proposal_len_idx_ptr = 0 + seq_idx = 0 + while seq_idx < len( + proposal_lens) and nonzero_proposal_len_idx_ptr < len( + nonzero_proposal_len_indices): + if seq_idx < nonzero_proposal_len_indices[ + nonzero_proposal_len_idx_ptr]: + # Sequence is not in the original nonzero_proposal_len_indices, + # meaning that it has a proposal length of 0 before sending to + # the draft worker. + assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + def _merge_outputs( self, batch_size: int, From 1356df53bd5d6877358aff3d2bbd95f28f8009a4 Mon Sep 17 00:00:00 2001 From: Stephen Krider <72541272+skrider@users.noreply.github.com> Date: Mon, 13 May 2024 15:50:33 -0700 Subject: [PATCH 037/124] [Kernel] Use flash-attn for decoding (#3648) Co-authored-by: Woosuk Kwon Co-authored-by: LiuXiaoxuanPKU --- tests/kernels/test_flash_attn.py | 209 ++++++++++++++++++++++++++ tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++++------- vllm/attention/selector.py | 14 ++ vllm/worker/model_runner.py | 15 +- 6 files changed, 313 insertions(+), 65 deletions(-) create mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..89bdacc67fbc4 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,209 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c02204f16ac68..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index e87a1783a83f1..664e951a89f2a 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..11ecb2792ea9d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,20 +1,16 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) + +_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] class FlashAttentionBackend(AttentionBackend): @@ -38,8 +34,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -47,19 +44,26 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -105,6 +109,14 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + class FlashAttentionImpl(AttentionImpl): """ @@ -156,11 +168,15 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") def forward( self, @@ -171,17 +187,20 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -189,16 +208,20 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -218,7 +241,8 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -239,38 +263,32 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - output[:num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.subquery_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.subquery_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, ) + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, + output[num_prefill_tokens:] = flash_attn_with_kvcache( + decode_query.unsqueeze(1), key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_scale, - ) + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 06f99718a4dee..5140c3cc86a31 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,6 +93,20 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + + if sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..3f7e87c1de48c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,20 +266,27 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) else: # The first prefill. - prefix_block_tables.append([]) + block_table = [] else: - prefix_block_tables.append([]) + block_table = [] # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 + prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 33d3914b1e6d85a855da1a69193030c1915cb6f9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 13 May 2024 16:00:27 -0700 Subject: [PATCH 038/124] [Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793) --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 113abbaa6036d..e3ac33e0452fe 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,7 +95,7 @@ def __init__( params_dtype=self.params_dtype, quant_config=None) - if self.use_fp8: + if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn self.w13_weight = nn.Parameter( From ac1fbf7fd2d1fdddc7b4953eeb3acae35c62766f Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 13 May 2024 16:23:54 -0700 Subject: [PATCH 039/124] [Doc] Shorten README by removing supported model list (#4796) --- README.md | 47 +++++-------------------- docs/source/models/supported_models.rst | 25 +++++++++---- 2 files changed, 28 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 524d027137aba..b704441eada9f 100644 --- a/README.md +++ b/README.md @@ -51,41 +51,14 @@ vLLM is flexible and easy to use with: - (Experimental) Prefix caching support - (Experimental) Multi-lora support -vLLM seamlessly supports many Hugging Face models, including the following architectures: - -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) -- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) -- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) -- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) -- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) -- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) -- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) -- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) -- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) -- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) -- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) -- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) +vLLM seamlessly supports most popular open-source models on HuggingFace, including: +- Transformer-like LLMs (e.g., Llama) +- Mixture-of-Expert LLMs (e.g., Mixtral) +- Multi-modal LLMs (e.g., LLaVA) + +Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). + +## Getting Started Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): @@ -93,9 +66,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get pip install vllm ``` -## Getting Started - -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. +Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ceb658bbd5c66..142c8f8573e2f 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -16,13 +16,21 @@ Alongside each architecture, we include some popular models that use it. - Example HuggingFace Models - :ref:`LoRA ` * - :code:`AquilaForCausalLM` - - Aquila + - Aquila & Aquila2 - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. - ✅︎ + * - :code:`ArcticForCausalLM` + - Arctic + - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc. + - * - :code:`BaiChuanForCausalLM` - - Baichuan + - Baichuan & Baichuan2 - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. - ✅︎ + * - :code:`BloomForCausalLM` + - BLOOM, BLOOMZ, BLOOMChat + - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. + - * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. @@ -39,10 +47,6 @@ Alongside each architecture, we include some popular models that use it. - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - * - :code:`BloomForCausalLM` - - BLOOM, BLOOMZ, BLOOMChat - - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. @@ -135,6 +139,15 @@ Alongside each architecture, we include some popular models that use it. - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - + * - :code:`Starcoder2ForCausalLM` + - Starcoder2 + - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. + - + * - :code:`XverseForCausalLM` + - Xverse + - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. + - + If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. From 4bfa7e7f75eb5b1a397c93aeea1dea1afa867b2a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 08:47:42 +0800 Subject: [PATCH 040/124] [Doc] Add API reference for offline inference (#4710) --- docs/source/index.rst | 8 +++++++- docs/source/offline_inference/llm.rst | 6 ++++++ .../source/{dev => offline_inference}/sampling_params.rst | 4 ++-- docs/source/serving/openai_compatible_server.md | 4 ++-- 4 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 docs/source/offline_inference/llm.rst rename docs/source/{dev => offline_inference}/sampling_params.rst (55%) diff --git a/docs/source/index.rst b/docs/source/index.rst index 4022c590843e6..e1e81778dbdb7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -67,6 +67,13 @@ Documentation getting_started/quickstart getting_started/examples/examples_index +.. toctree:: + :maxdepth: 1 + :caption: Offline Inference + + offline_inference/llm + offline_inference/sampling_params + .. toctree:: :maxdepth: 1 :caption: Serving @@ -101,7 +108,6 @@ Documentation :maxdepth: 2 :caption: Developer Documentation - dev/sampling_params dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile diff --git a/docs/source/offline_inference/llm.rst b/docs/source/offline_inference/llm.rst new file mode 100644 index 0000000000000..1a443ea406994 --- /dev/null +++ b/docs/source/offline_inference/llm.rst @@ -0,0 +1,6 @@ +LLM Class +========== + +.. autoclass:: vllm.LLM + :members: + :show-inheritance: diff --git a/docs/source/dev/sampling_params.rst b/docs/source/offline_inference/sampling_params.rst similarity index 55% rename from docs/source/dev/sampling_params.rst rename to docs/source/offline_inference/sampling_params.rst index ef3d1509bda6d..f645941a6c022 100644 --- a/docs/source/dev/sampling_params.rst +++ b/docs/source/offline_inference/sampling_params.rst @@ -1,5 +1,5 @@ -Sampling Params -=============== +Sampling Parameters +=================== .. autoclass:: vllm.SamplingParams :members: diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 15a8761eb5738..a775c6addf1d9 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python From c579b750a083931ad03ecac898aca5ad67c6c59c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 13 May 2024 18:48:00 -0700 Subject: [PATCH 041/124] [Doc] Add meetups to the doc (#4798) --- docs/source/community/meetups.rst | 12 ++++++++++++ docs/source/index.rst | 7 +++++++ 2 files changed, 19 insertions(+) create mode 100644 docs/source/community/meetups.rst diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst new file mode 100644 index 0000000000000..fa1a26521814d --- /dev/null +++ b/docs/source/community/meetups.rst @@ -0,0 +1,12 @@ +.. _meetups: + +vLLM Meetups +============ + +We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: + +- `The third vLLM meetup `_, with Roblox, April 2nd 2024, `slides `_. +- `The second vLLM meetup `_, with IBM Research, January 31st 2024, `slides `_, `video (vLLM Update) `_, `video (IBM Research & torch.compile) `_. +- `The first vLLM meetup `_, with a16z, October 5th, 2023, `slides `_. + +We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index e1e81778dbdb7..bab00e28e4018 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,6 +50,7 @@ For more information, check out the following: * `vLLM announcing blog post `_ (intro to PagedAttention) * `vLLM paper `_ (SOSP 2023) * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency `_ by Cade Daniel et al. +* :ref:`vLLM Meetups `. @@ -112,6 +113,12 @@ Documentation dev/kernel/paged_attention dev/dockerfile/dockerfile +.. toctree:: + :maxdepth: 2 + :caption: Community + + community/meetups + Indices and tables ================== From ccb63a8245bceb9e6ba260eeef41b54ca8bdb370 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 14 May 2024 05:34:33 -0700 Subject: [PATCH 042/124] [Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (#4696) --- benchmarks/overheads/benchmark_hashing.py | 63 +++++++++++++++++++++++ vllm/sequence.py | 16 +++++- 2 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 benchmarks/overheads/benchmark_hashing.py diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py new file mode 100644 index 0000000000000..c846e47de1fcf --- /dev/null +++ b/benchmarks/overheads/benchmark_hashing.py @@ -0,0 +1,63 @@ +import argparse +import cProfile +import pstats + +from vllm import LLM, SamplingParams + +# A very long prompt, total number of tokens is about 15k. +LONG_PROMPT = ["You are an expert in large language models, aren't you?" + ] * 1000 +LONG_PROMPT = ' '.join(LONG_PROMPT) + + +def main(args): + llm = LLM( + model=args.model, + enforce_eager=True, + enable_prefix_caching=True, + tensor_parallel_size=args.tensor_parallel_size, + use_v2_block_manager=args.use_v2_block_manager, + ) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + profiler = cProfile.Profile() + + print("------warm up------") + for i in range(3): + output = llm.generate(LONG_PROMPT, sampling_params) + print(output[0].outputs[0].text) + + print("------start generating------") + for i in range(3): + profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', + globals(), locals()) + + # analyze the runtime of hashing function + stats = pstats.Stats(profiler) + stats.sort_stats('cumulative') + total_time = 0 + total_calls = 0 + for func in stats.stats: + if 'hash_of_block' in func[2]: + total_time = stats.stats[func][3] + total_calls = stats.stats[func][0] + percentage = (total_time / stats.total_tt) * 100 + print(f"Hashing took {total_time:.2f} seconds," + f"{percentage:.2f}% of the total runtime.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Benchmark the performance of hashing function in' + 'automatic prefix caching.') + parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') + args = parser.parse_args() + main(args) diff --git a/vllm/sequence.py b/vllm/sequence.py index 46ac33b7ecabd..12e930c27173e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -121,6 +121,7 @@ def __init__( output_token_ids = [] self.prompt_token_ids = prompt_token_ids + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). @@ -143,6 +144,17 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_prefix_token_ids( + self, num_tokens: int + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = len(self.prompt_token_ids) + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self.output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens @@ -253,8 +265,8 @@ def hash_of_block(self, logical_idx: int) -> int: # TODO: The current hashing function is O(L^2). We should optimize # this in the future. num_tokens = self.num_hashed_tokens_of_block(logical_idx) - return hash( - (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id)) + hashed_tokens = self.data.get_prefix_token_ids(num_tokens) + return hash((hashed_tokens, self.lora_int_id)) def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size From dc72402b5785a6ffadff59d4e018661278d4b028 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 00:57:08 +0800 Subject: [PATCH 043/124] [Bugfix][Doc] Fix CI failure in docs (#4804) This PR fixes the CI failure introduced by #4798. The failure originates from having duplicate target names in reST, and is fixed by changing the ref targets to anonymous ones. For more information, see this discussion. I have also changed the format of the links to be more distinct from each other. --- docs/source/community/meetups.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst index fa1a26521814d..f371194781de3 100644 --- a/docs/source/community/meetups.rst +++ b/docs/source/community/meetups.rst @@ -5,8 +5,8 @@ vLLM Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: -- `The third vLLM meetup `_, with Roblox, April 2nd 2024, `slides `_. -- `The second vLLM meetup `_, with IBM Research, January 31st 2024, `slides `_, `video (vLLM Update) `_, `video (IBM Research & torch.compile) `_. -- `The first vLLM meetup `_, with a16z, October 5th, 2023, `slides `_. +- `The third vLLM meetup `__, with Roblox, April 2nd 2024. `[Slides] `__ +- `The second vLLM meetup `__, with IBM Research, January 31st 2024. `[Slides] `__ `[Video (vLLM Update)] `__ `[Video (IBM Research & torch.compile)] `__ +- `The first vLLM meetup `__, with a16z, October 5th 2023. `[Slides] `__ -We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `_. +We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `__. From 676a99982fe9aabe72fd52a91e08988a653a7359 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 14 May 2024 10:38:59 -0700 Subject: [PATCH 044/124] [Core] Add MultiprocessingGPUExecutor (#4539) Co-authored-by: SAHIL SUNEJA --- .buildkite/test-pipeline.yaml | 12 +- .../test_basic_distributed_correctness.py | 17 ++- .../test_chunked_prefill_distributed.py | 4 + tests/lora/test_mixtral.py | 3 +- vllm/config.py | 38 +++-- vllm/engine/arg_utils.py | 16 +- vllm/engine/async_llm_engine.py | 16 +- vllm/engine/llm_engine.py | 10 +- vllm/executor/multiproc_gpu_executor.py | 140 ++++++++++++++++++ vllm/executor/ray_gpu_executor.py | 4 +- vllm/executor/ray_utils.py | 4 +- 11 files changed, 225 insertions(+), 39 deletions(-) create mode 100644 vllm/executor/multiproc_gpu_executor.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3c3da41c3abf3..eeb6eaa2165bc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -34,10 +34,14 @@ steps: mirror_hardwares: [amd] commands: - pytest -v -s distributed/test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index d63f015511ada..3ba5cea389c2f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -20,6 +20,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @@ -36,19 +37,21 @@ def test_models( dtype: str, max_tokens: int, ) -> None: - enforce_eager = False + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - if backend_by_env_var == "FLASHINFER": - enforce_eager = True + enforce_eager = backend_by_env_var == "FLASHINFER" hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - enforce_eager=enforce_eager) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager, + distributed_executor_backend=distributed_executor_backend) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 737b1f3169519..db938cc613c6b 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -19,6 +19,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -36,6 +37,8 @@ def test_models( max_tokens: int, chunked_prefill_token_size: int, ) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) assert chunked_prefill_token_size != -1 @@ -53,6 +56,7 @@ def test_models( max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 4d74722aaa926..f6a8a50fa9e50 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): enable_lora=True, max_num_seqs=16, max_loras=4, - tensor_parallel_size=tp_size, - worker_use_ray=True) + tensor_parallel_size=tp_size) expected_lora_output = [ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 435f47dc9459a..26edd4567b9ac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -521,9 +521,7 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to - True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + worker_use_ray: Deprecated, use distributed_executor_backend instead. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -533,22 +531,27 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + distributed_executor_backend: Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. """ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, + distributed_executor_backend: Optional[str] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size - self.worker_use_ray = worker_use_ray + self.distributed_executor_backend = distributed_executor_backend self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.tokenizer_pool_config = tokenizer_pool_config @@ -556,14 +559,29 @@ def __init__( self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - self.worker_use_ray = True + if worker_use_ray: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + elif self.distributed_executor_backend != "ray": + raise ValueError(f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'.") + + if self.distributed_executor_backend is None and self.world_size > 1: + from vllm.executor import ray_utils + ray_found = ray_utils.ray is not None + self.distributed_executor_backend = "ray" if ray_found else "mp" + self._verify_args() def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if self.distributed_executor_backend not in ("ray", "mp", None): + raise ValueError( + "Unrecognized distributed executor backend. Supported values " + "are 'ray' or 'mp'.") if not self.disable_custom_all_reduce and self.world_size > 1: if is_hip(): self.disable_custom_all_reduce = True @@ -575,7 +593,8 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported with pipeline parallelism.") - if self.ray_workers_use_nsight and not self.worker_use_ray: + if self.ray_workers_use_nsight and ( + not self.distributed_executor_backend == "ray"): raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") @@ -887,7 +906,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=target_parallel_config.tensor_parallel_size, - worker_use_ray=target_parallel_config.worker_use_ray, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fd5338c46c340..195d9e1b33e3c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + distributed_executor_backend: Optional[str] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -221,10 +222,17 @@ def add_cli_args( ' Can be overridden per request via guided_decoding_backend' ' parameter.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='Use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a31f10b7748d3..8a37bac02823a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -348,27 +348,31 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "cpu": - assert not engine_config.parallel_config.worker_use_ray, ( - "Ray is not supported with the CPU backend.") + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - engine_config.parallel_config.worker_use_ray, + distributed_executor_backend == "ray", engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e258a3f4afd54..f6a5284093c1c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -274,6 +274,8 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": @@ -282,13 +284,15 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + executor_class = MultiprocessingGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 0000000000000..2a7b99c9dcbe1 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,140 @@ +import asyncio +import os +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +logger = init_logger(__name__) + + +class MultiprocessingGPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( + map(str, range(world_size)))) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, + DistributedGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. + coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index afc1c886722e6..9cb03ec8c3f5a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -31,7 +31,7 @@ def _init_executor(self) -> None: assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -264,7 +264,7 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 9db3ae2ff8298..4704f5f1b1a10 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -44,7 +44,7 @@ def execute_model_compiled_dag_remote(self, ignored): except ImportError as e: logger.warning( - "Failed to import Ray with %r. For distributed inference, " + "Failed to import Ray with %r. For multi-node inference, " "please install Ray with `pip install ray`.", e) ray = None # type: ignore RayWorkerWrapper = None # type: ignore @@ -67,7 +67,7 @@ def initialize_ray_cluster( """ if ray is None: raise ImportError( - "Ray is not installed. Please install Ray to use distributed " + "Ray is not installed. Please install Ray to use multi-node " "serving.") # Connect to a ray cluster. From 29bc01bf3bc26642e4cee15ebd36a6ce5799326d Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 14 May 2024 15:33:06 -0700 Subject: [PATCH 045/124] Add 4th meetup announcement to readme (#4817) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b704441eada9f..4fe5d9630f858 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥 +- [2024/05] We are hosting [the fourth vLLM meetup](https://lu.ma/event/manage/evt-A064fGpj52fviSn) with BentoML and Cloudflare on June 11! Please register [here](https://lu.ma/agivllm). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] Added ROCm 6.0 support to vLLM. From 8a7cc254a064b8d42bf4de7a9c3f29552240dfd9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 11:52:45 +0900 Subject: [PATCH 046/124] Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820) Lora 3 & 4 test seems to have illegal memory access failure after this commit; [2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered
Exmaple: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241 This reverts commit 1356df5. FILL IN THE PR DESCRIPTION HERE FIX #xxxx (link existing issues this PR will resolve) --- tests/kernels/test_flash_attn.py | 209 -------------------------- tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++--------- vllm/attention/selector.py | 14 -- vllm/worker/model_runner.py | 15 +- 6 files changed, 65 insertions(+), 313 deletions(-) delete mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py deleted file mode 100644 index 89bdacc67fbc4..0000000000000 --- a/tests/kernels/test_flash_attn.py +++ /dev/null @@ -1,209 +0,0 @@ -from typing import List, Optional, Tuple - -import pytest -import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache - -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] - - -def ref_paged_attn( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, - sliding_window: Optional[int] = None, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k).float() - empty_mask = torch.ones(query_len, kv_len) - mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() - if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() - mask |= sliding_window_mask - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1).to(v.dtype) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_flash_attn_with_paged_kv( - kv_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - ).squeeze(1) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - sliding_window: Optional[int], - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else - (-1, -1)) - scale = head_size**-0.5 - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - # Normalize the scale of the key and value caches to mitigate - # numerical instability. - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - sliding_window=sliding_window, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34e75..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - # "mosaicml/mpt-7b", # Broken + "mosaicml/mpt-7b", # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..e87a1783a83f1 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 11ecb2792ea9d..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,16 +1,20 @@ -"""Attention layer with FlashAttention.""" +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) - -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -34,9 +38,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -44,26 +47,19 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -109,14 +105,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - class FlashAttentionImpl(AttentionImpl): """ @@ -168,15 +156,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, @@ -187,20 +171,17 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with FlashAttention and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -208,20 +189,16 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - ) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -241,8 +218,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -263,32 +239,38 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], ) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,20 +93,6 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3f7e87c1de48c..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,27 +266,20 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums + prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) else: # The first prefill. - block_table = [] + prefix_block_tables.append([]) else: - block_table = [] + prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 - prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 65bf2ac165734fb6339210c4b2b8ce68d2391b77 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 14:00:10 +0900 Subject: [PATCH 047/124] [Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681) This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend. It also refactors subquery_start_loc which was not refactored in the previous PR --- tests/worker/test_model_runner.py | 123 ++- vllm/attention/__init__.py | 5 +- vllm/attention/backends/abstract.py | 68 +- vllm/attention/backends/flash_attn.py | 95 ++- vllm/attention/backends/flashinfer.py | 38 +- vllm/attention/backends/rocm_flash_attn.py | 98 ++- vllm/attention/backends/torch_sdpa.py | 28 +- vllm/attention/backends/xformers.py | 92 ++- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 10 +- vllm/engine/arg_utils.py | 5 + .../layers/rejection_sampler.py | 1 + vllm/sequence.py | 3 +- vllm/spec_decode/batch_expansion.py | 23 +- vllm/spec_decode/multi_step_worker.py | 1 + vllm/worker/cpu_model_runner.py | 10 +- vllm/worker/embedding_model_runner.py | 130 +-- vllm/worker/model_runner.py | 772 ++++++++---------- 18 files changed, 777 insertions(+), 730 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index c2d1c5769619b..92de545acd53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is True + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 assert torch.allclose( attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 # Test subquery start locs. start_idx = 0 @@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size): start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( - attn_metadata.subquery_start_loc, + attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. Note that for normal prefill it is - # equivalent to subquery_start_loc. + # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] for seq_len in seq_lens: @@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, ) - seq_lens = [] + context_lens = [] seq_group_metadata_list = [] + # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + context_lens.append(context_len) + seq_data = list(range(context_len)) seq_data = SequenceData(seq_data) + seq_data.update_num_computed_tokens(context_len) + # Append one token ID since prefill is finished. + seq_data.append_token_id(1, 0) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, model_input.input_positions, + model_input.attn_metadata, model_input.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is False - assert attn_metadata.seq_lens is None - assert attn_metadata.subquery_start_loc is None - assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + seq_lens = [context_len + 1 for context_len in context_lens] + # seq_lens are padded to expected_bs + for _ in range(expected_bs - len(seq_lens)): + seq_lens.append(1) + assert attn_metadata.seq_lens == seq_lens + start_idx = 0 + start_loc = [start_idx] + for _ in context_lens: + # decode has only 1 token for query. + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor(context_lens, dtype=torch.int, device=device)) + assert attn_metadata.max_decode_seq_len == max(seq_lens) assert torch.allclose( attn_metadata.seq_lens_tensor[:len(seq_lens)], torch.tensor(seq_lens, dtype=torch.int, device=device)) @@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size): # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) - # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seq_len in seq_lens: + for _ in context_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, - query_lens=seq_lens, + # query lens is all 1 for decode. + query_lens=[1 for _ in range(len(context_lens))], device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -220,15 +257,27 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list = [] - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None @@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(context_len)) seq_data = SequenceData(prompt_toks) + seq_data.append_token_id(1, 0) + seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) assert attn_metadata.num_prefills == prefill_batch_size - if enforce_eager: - assert attn_metadata.num_decode_tokens == decode_batch_size - else: - assert attn_metadata.num_decode_tokens == _get_graph_batch_size( - decode_batch_size) + assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - prefill_meta = model_runner._prepare_prompt( - prefill_metadata_list).attn_metadata - decode_meta = model_runner._prepare_decode( - decode_metadata_list).attn_metadata + attn_metadata = model_runner._prepare_model_input( + seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(prefill_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), vars(prefill_meta_actual)): assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(decode_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 088f48def7668..f6bce9a187c64 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,5 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,6 +7,6 @@ "Attention", "AttentionBackend", "AttentionMetadata", - "AttentionMetadataPerStage", + "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 98d70fcab1a18..94ab64de30a94 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,7 +21,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": + def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError @staticmethod @@ -53,8 +53,34 @@ def copy_blocks( @dataclass -class AttentionMetadataPerStage: - """Attention metadata for a specific stage. I.e., prefill or decode.""" +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -70,40 +96,10 @@ def asdict_zerocopy(self, } -T = TypeVar("T", bound=AttentionMetadataPerStage) - - -@dataclass -class AttentionMetadata(Generic[T]): - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # The attention metadata for prefill requests in a batch. - # None if there's no prefill requests in a batch. - prefill_metadata: Optional[T] - # The attention metadata for decode requests in a batch. - # None if there's no decode requests in a batch. - decode_metadata: Optional[T] - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - def __post_init__(self): - if self.num_prefill_tokens > 0: - assert self.num_prefills > 0 - assert self.prefill_metadata is not None - if self.num_decode_tokens > 0: - assert self.decode_metadata is not None +T = TypeVar("T", bound=AttentionMetadata) -class AttentionImpl(ABC): +class AttentionImpl(ABC, Generic[T]): @abstractmethod def __init__( @@ -125,7 +121,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: T, kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..5d1f65819ed4e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,8 +11,7 @@ from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -58,8 +57,7 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + class FlashAttentionImpl(AttentionImpl): """ @@ -168,7 +231,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[FlashAttentionMetadata], + attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -228,8 +291,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -249,7 +312,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -264,7 +327,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 92d0fe0487516..5f9fd586fb70e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,8 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) class FlashInferBackend(AttentionBackend): @@ -56,9 +55,10 @@ def get_supported_head_sizes() -> List[int]: @dataclass -class FlashInferMetadata(AttentionMetadataPerStage): - - is_prompt: bool +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int use_cuda_graph: bool = False @@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the prefill stage since we still # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None - max_seq_len: Optional[int] = None block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage @@ -113,7 +112,8 @@ def __post_init__(self): # When using flashinfer, we are also creating the FlashInferMetadata, # which will also call post_init by default, here we want to skip the # post_init if it's the prefill phase. - if not self.is_prompt: + if self.num_prefills == 0: + assert self.num_decode_tokens > 0 self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") self.decode_wrapper.begin_forward( @@ -138,6 +138,24 @@ def asdict_zerocopy(self, skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + class FlashInferImpl(AttentionImpl): @@ -172,7 +190,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], + attn_metadata: FlashInferMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: assert kv_scale == 1.0 @@ -208,8 +226,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 539585b46c7aa..1a94dc3596358 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,8 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -56,8 +55,7 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata class ROCmFlashAttentionImpl(AttentionImpl): @@ -198,7 +260,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], + attn_metadata: ROCmFlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -266,8 +328,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_seq_len, - prefill_meta.max_seq_len, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, True, self.scale, ) @@ -290,8 +352,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) @@ -308,7 +370,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -324,7 +386,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2dd72a00c6e30..a3f72b9c94566 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,8 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -54,8 +53,7 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, - AttentionMetadataPerStage): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -72,8 +70,26 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self -class TorchSDPABackendImpl(AttentionImpl): + return None + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( self, @@ -200,7 +216,7 @@ def forward( value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, - attn_metadata.max_seq_len, + attn_metadata.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cb2028553461f..fc46af054de4f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,8 +9,7 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -59,7 +58,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -114,8 +116,68 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - -class XFormersImpl(AttentionImpl): + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -176,7 +238,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[XFormersMetadata], + attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -244,7 +306,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -261,7 +323,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 8a872dba8c877..126692d8c9b40 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import (AttentionMetadata, - AttentionMetadataPerStage) +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -57,7 +56,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[AttentionMetadataPerStage], + attn_metadata: AttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 3c010b67b3120..30feaa4da254d 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -16,8 +16,8 @@ class PagedAttentionMetadata: # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -166,7 +166,7 @@ def forward_prefix( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, + query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, @@ -182,8 +182,8 @@ def forward_prefix( key_cache, value_cache, block_tables, - # subquery_start_loc is (batch_size + 1,) - subquery_start_loc[:-1], + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], seq_lens_tensor, context_lens, max_query_len, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 195d9e1b33e3c..bd44c2470182b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -618,6 +618,11 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled): + raise ValueError( + "Chunked prefill is not supported with sliding window.") + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index b5f1e55d0e839..1f2ab7e2870ca 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -122,6 +122,7 @@ def forward( draft_token_ids, bonus_token_ids, ) + return output_token_ids def _batch_modified_rejection_sampling( diff --git a/vllm/sequence.py b/vllm/sequence.py index 12e930c27173e..aa759448d82b1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -654,8 +654,9 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> Optional[int]: + def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" + assert self._token_chunk_size is not None return self._token_chunk_size diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index d5fd96907ddd7..7792f3a3425cc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -293,21 +293,30 @@ def _create_single_target_seq_group_metadata( prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), - }, + seq_data=new_seq_data_dict, sampling_params=seq_group_metadata.sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, + token_chunk_size=1, ) def _split_scoring_output( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 20098ebaeea32..b5a805278d273 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -114,6 +114,7 @@ def _append_new_tokens( token_logprob = seq_output.logprobs[token_id] seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0a0b0d70cfe21..bc88f2c5bed6c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -159,12 +159,10 @@ def _prepare_prompt( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=None, - max_seq_len=None, + max_decode_seq_len=None, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, - prefill_metadata=None, - decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) @@ -213,7 +211,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_seq_len = max(seq_lens) + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -243,12 +241,10 @@ def _prepare_decode( slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_seq_len=max_seq_len, + max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, - prefill_metadata=None, - decode_metadata=None, block_tables=block_tables, ) return ( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d04bebbdc31b6..91f30978ead87 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import BatchType, ModelRunner +from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -88,85 +88,24 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, - prompt_lens, - subquery_lens, - lora_index_mapping, - lora_prompt_mapping, + attn_metadata, + seq_lens, + _, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) - + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - prompt_lens) - - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(prompt_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE + seq_lens) metadata_dict = { "input_tokens": input_tokens, @@ -178,65 +117,26 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - + attn_metadata = None pooling_metadata = PoolingMetadata(seq_groups=None, seq_data=None, prompt_lens=None) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..dcdd4b962454e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,13 +1,11 @@ import time -from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn -from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -37,66 +35,38 @@ ] -class PreparePromptMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadataPerStage] +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] seq_lens: List[int] query_lens: List[int] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] + lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] multi_modal_input: Optional[torch.Tensor] - slot_mapping: List[int] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int @classmethod - def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), attn_metadata=None, seq_lens=[], query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], + lora_mapping=None, lora_requests=set(), multi_modal_input=None, - slot_mapping=[], - ) - - -class PrepareDecodeMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadata] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] - lora_requests: Set[LoRARequest] - slot_mapping: List[int] - - @classmethod - def empty(cls): - return PrepareDecodeMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, ) -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - class ModelRunner: def __init__( @@ -216,10 +186,22 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_prompt( + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PreparePromptMetadata: + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -228,212 +210,16 @@ def _prepare_prompt( lora_requests: Set[LoRARequest] = set() seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - if len(seq_group_metadata_list) == 0: - return PreparePromptMetadata.empty() - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - context_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - seq_len = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - context_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert context_len == 0 - - # actual prompt lens - context_lens.append(context_len) - query_lens.append(seq_len - context_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend([lora_id] * ( - seq_len - context_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - assert context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - max_query_len = max(query_lens) - max_seq_len = max(seq_lens) - assert max_query_len > 0 - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - if self.attn_backend.get_name() == "flashinfer": - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - use_cuda_graph=False, - seq_start_loc=seq_start_loc, - max_seq_len=max_seq_len, - block_tables=block_tables) - else: - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - - return PreparePromptMetadata( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - multi_modal_input=multi_modal_input, - slot_mapping=slot_mapping, - ) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PrepareDecodeMetadata: - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] block_tables: List[List[int]] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() + multi_modal_input_list: List[torch.Tensor] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout @@ -454,60 +240,186 @@ def _prepare_decode( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return PrepareDecodeMetadata.empty() + return ModelInput.empty(self.device) for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) + is_prompt = seq_group_metadata.is_prompt for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert (not self.scheduler_config. + chunked_prefill_enabled) + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + seq_len = min(seq_len, self.sliding_window) + context_len = seq_len - 1 - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) seq_lens.append(seq_len) + context_lens.append(context_len) + query_len = seq_len - context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (seq_len - + context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + if is_prompt: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) - last_page_len = seq_data.get_len() % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) - # vLLM uses cuda graph only for decoding requests. + # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. - batch_size = len(input_tokens) - max_seq_len = max(seq_lens) - use_captured_graph = (not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -519,18 +431,9 @@ def _prepare_decode( block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + num_decode_tokens = batch_size if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert seq_lens_tensor.shape[0] == len(input_tokens) - assert seq_lens_tensor.shape[0] == len(input_positions) - assert seq_lens_tensor.shape[0] == len(slot_mapping) - # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -548,6 +451,57 @@ def _prepare_decode( dtype=torch.int, device=self.device, ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): @@ -555,53 +509,75 @@ def _prepare_decode( # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html self.flashinfer_workspace_buffer = torch.empty( 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, num_qo_heads=self.model_config.get_num_attention_heads( self.parallel_config), num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=self.block_size, + page_size=16, + seq_start_loc=seq_start_loc, data_type=kv_cache_dtype) else: attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - max_seq_len=max_seq_len, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return PrepareDecodeMetadata( - input_tokens=input_tokens, - input_positions=input_positions, + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, attn_metadata=attn_metadata, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, lora_requests=lora_requests, - slot_mapping=slot_mapping, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, ) def prepare_input_tensors( @@ -610,85 +586,25 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, + attn_metadata, seq_lens, query_lens, - lora_index_mapping, - lora_prompt_mapping, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(seq_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE - metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -701,46 +617,24 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + attn_metadata = None sampling_metadata = SamplingMetadata( seq_groups=None, selected_token_indices=selected_token_indices, @@ -748,22 +642,6 @@ def prepare_input_tensors( num_prompts=0, ) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -954,26 +832,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - decode_metadata = self.attn_backend.make_metadata( - is_prompt=False, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=seq_lens[:batch_size], max_query_len=None, - max_seq_len=self.max_seq_len_to_capture, - subquery_start_loc=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) - attn_metadata = AttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - prefill_metadata=None, - decode_metadata=decode_metadata, - ) if self.lora_config: lora_mapping = LoRAMapping( From e9cdd2b1e20beb1c21c55441d0e6a4ed86f4e292 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 14:38:40 +0800 Subject: [PATCH 048/124] [CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166) --- tests/conftest.py | 77 +++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 999ace2c3c699..c1a44a606e1bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,21 @@ import contextlib import gc import os -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import pytest import torch from PIL import Image -from transformers import (AutoModelForCausalLM, AutoProcessor, - LlavaForConditionalGeneration) +from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, + LlavaConfig, LlavaForConditionalGeneration) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.logger import init_logger from vllm.sequence import MultiModalData -from vllm.transformers_utils.tokenizer import get_tokenizer + +logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -129,9 +131,7 @@ def example_long_prompts() -> List[str]: "float": torch.float, } -_VISION_LANGUAGE_MODELS = { - "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, -} +AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration) _EMBEDDING_MODELS = [ "intfloat/e5-mistral-7b-instruct", @@ -143,23 +143,14 @@ class HfRunner: def __init__( self, model_name: str, - tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + self.model_name = model_name - if model_name in _VISION_LANGUAGE_MODELS: - self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = AutoProcessor.from_pretrained( - model_name, - torch_dtype=torch_dtype, - ) - elif model_name in _EMBEDDING_MODELS: + + if model_name in _EMBEDDING_MODELS: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer self.model = SentenceTransformer( @@ -172,10 +163,24 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ).cuda() - self.processor = None - if tokenizer_name is None: - tokenizer_name = model_name - self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. Using tokenizer instead.", model_name) + self.processor = self.tokenizer def generate( self, @@ -187,19 +192,19 @@ def generate( if images: assert len(prompts) == len(images) for i, prompt in enumerate(prompts): - if self.model_name not in _VISION_LANGUAGE_MODELS: - input_ids = self.tokenizer(prompt, - return_tensors="pt").input_ids - inputs = {"input_ids": input_ids.cuda()} - else: - image = images[i] if images else None - inputs = self.processor(text=prompt, - images=image, - return_tensors="pt") - inputs = { - key: value.cuda() if value is not None else None - for key, value in inputs.items() - } + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + inputs = { + key: value.cuda() if value is not None else None + for key, value in inputs.items() + } + output_ids = self.model.generate( **inputs, use_cache=True, From a5675d348b126e53928e139d1ed5b2c00a0044e8 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 15 May 2024 07:22:09 -0700 Subject: [PATCH 049/124] [Bugfix] Properly set distributed_executor_backend in ParallelConfig (#4816) --- vllm/config.py | 1 + vllm/engine/arg_utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 26edd4567b9ac..2eb5bdd18d812 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -531,6 +531,7 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + placement_group: ray distributed model workers placement group. distributed_executor_backend: Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If either pipeline_parallel_size or tensor_parallel_size is greater than 1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd44c2470182b..dab86b7c9eb35 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -548,14 +548,18 @@ def create_engine_config(self, ) -> EngineConfig: model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( - self.pipeline_parallel_size, self.tensor_parallel_size, - self.worker_use_ray, self.max_parallel_loading_workers, + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, self.disable_custom_all_reduce, TokenizerPoolConfig.create_config( self.tokenizer_pool_size, self.tokenizer_pool_type, self.tokenizer_pool_extra_config, - ), self.ray_workers_use_nsight) + ), + self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, From 361c461a128a5df2faefeb70ffa98e61e4feda55 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 15 May 2024 11:38:49 -0700 Subject: [PATCH 050/124] [Doc] Highlight the fourth meetup in the README (#4842) --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4fe5d9630f858..fc3f71b00c3c5 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,18 @@ Easy, fast, and cheap LLM serving for everyone

+--- + +**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)** + +We are thrilled to announce our fourth vLLM Meetup! +The vLLM team will share recent updates and roadmap. +We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM. +Please register [here](https://lu.ma/agivllm) and join us! + +--- + *Latest News* 🔥 -- [2024/05] We are hosting [the fourth vLLM meetup](https://lu.ma/event/manage/evt-A064fGpj52fviSn) with BentoML and Cloudflare on June 11! Please register [here](https://lu.ma/agivllm). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] Added ROCm 6.0 support to vLLM. From fc0d9dfc3afcea2e23649ef8eb8bbe0446682813 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 16 May 2024 05:58:46 +0800 Subject: [PATCH 051/124] [Frontend] Re-enable custom roles in Chat Completions API (#4758) --- tests/entrypoints/test_openai_server.py | 30 +++++++++++ vllm/entrypoints/openai/protocol.py | 38 +++++++++++++- vllm/entrypoints/openai/serving_chat.py | 66 ++++++++++++++++--------- 3 files changed, 108 insertions(+), 26 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index ee2f034fd2c46..1b04e3205c4b8 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +async def test_custom_role(server, client: openai.AsyncOpenAI): + # Not sure how the model handles custom roles so we just check that + # both string and complex message content are handled in the same way + + resp1 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": "what is 1+1?", + }], # type: ignore + temperature=0, + seed=0) + + resp2 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": [{ + "type": "text", + "text": "what is 1+1?" + }] + }], # type: ignore + temperature=0, + seed=0) + + content1 = resp1.choices[0].message.content + content2 = resp2.choices[0].message.content + assert content1 == content2 + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 139c5716c7cea..35dfa09ac12ba 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,16 +3,50 @@ import time from typing import Any, Dict, List, Literal, Optional, Union +import openai.types.chat import torch -from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Annotated, Required, TypedDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam = Union[ + openai.types.chat.ChatCompletionContentPartParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + +ChatCompletionMessageParam = Union[ + openai.types.chat.ChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1b469fc59b076..65824a2206be9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,15 +1,16 @@ import codecs import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, - Optional, Tuple, TypedDict, Union, final) +from dataclasses import dataclass +from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, + TypedDict, Union, cast, final) from fastapi import Request -from openai.types.chat import (ChatCompletionContentPartParam, - ChatCompletionRole) +from openai.types.chat import ChatCompletionContentPartTextParam from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( + ChatCompletionContentPartParam, ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, @@ -31,6 +32,11 @@ class ConversationMessage(TypedDict): content: str +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + + class OpenAIServingChat(OpenAIServing): def __init__(self, @@ -77,27 +83,40 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.warning( "No chat template provided. Chat API will not work.") - def _parse_chat_message_content( + def _parse_chat_message_content_parts( self, - role: ChatCompletionRole, - content: Optional[Union[str, - Iterable[ChatCompletionContentPartParam]]], - ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: - if content is None: - return [], [] - if isinstance(content, str): - return [ConversationMessage(role=role, content=content)], [] - + role: str, + parts: Iterable[ChatCompletionContentPartParam], + ) -> ChatMessageParseResult: texts: List[str] = [] - for _, part in enumerate(content): - if part["type"] == "text": - text = part["text"] + + for _, part in enumerate(parts): + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) else: - raise NotImplementedError(f"Unknown part type: {part['type']}") + raise NotImplementedError(f"Unknown part type: {part_type}") + + messages = [ConversationMessage(role=role, content="\n".join(texts))] + + return ChatMessageParseResult(messages=messages) + + def _parse_chat_message_content( + self, + message: ChatCompletionMessageParam, + ) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages) - return [ConversationMessage(role=role, content="\n".join(texts))], [] + return self._parse_chat_message_content_parts(role, content) async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request @@ -119,11 +138,10 @@ async def create_chat_completion( try: conversation: List[ConversationMessage] = [] - for m in request.messages: - messages, _ = self._parse_chat_message_content( - m["role"], m["content"]) + for msg in request.messages: + parsed_msg = self._parse_chat_message_content(msg) - conversation.extend(messages) + conversation.extend(parsed_msg.messages) prompt = self.tokenizer.apply_chat_template( conversation=conversation, @@ -387,4 +405,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response \ No newline at end of file + return response From 52f8107cf2e5b3cc1a6a4a96c22b24505f02df01 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 15 May 2024 19:13:36 -0400 Subject: [PATCH 052/124] [Frontend] Support OpenAI batch file format (#4794) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- examples/offline_inference_openai.md | 172 +++++++++++++++++++++ examples/openi_example_batch.jsonl | 2 + requirements-common.txt | 1 + tests/entrypoints/test_openai_run_batch.py | 53 +++++++ vllm/entrypoints/openai/protocol.py | 41 +++++ vllm/entrypoints/openai/run_batch.py | 141 +++++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 8 +- 7 files changed, 415 insertions(+), 3 deletions(-) create mode 100644 examples/offline_inference_openai.md create mode 100644 examples/openi_example_batch.jsonl create mode 100644 tests/entrypoints/test_openai_run_batch.py create mode 100644 vllm/entrypoints/openai/run_batch.py diff --git a/examples/offline_inference_openai.md b/examples/offline_inference_openai.md new file mode 100644 index 0000000000000..40462ce1eb78c --- /dev/null +++ b/examples/offline_inference_openai.md @@ -0,0 +1,172 @@ +# Offline Inference with the OpenAI Batch file format + + **NOTE:** This is a guide to performing batch inference using the OpenAI batch file format, **NOT** the complete Batch (REST) API. + + ## File Format + + The OpenAI batch file format consists of a series of json objects on new lines. + + [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl) + + Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. + + **NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon). + + ## Pre-requisites + +* Ensure you are using `vllm >= 0.4.3`. You can check by running `python -c "import vllm; print(vllm.__version__)"`. +* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. + - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) + - Install the token on your machine (Run `huggingface-cli login`). + - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. + + + ## Example: Running with a local file + + ### Step 1: Create your batch file + + To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + + ### Step 2: Run the batch + +The batch running tool is designed to be used from the command line. + +You can run the batch with the following command, which will write its results to a file called `results.jsonl` + +``` +python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 3: Check your results + +You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` + +``` +$ cat ../results.jsonl +{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} +{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} +``` + +## Example 2: Using remote files + +The batch runner supports remote input and output urls that are accessible via http/https. + +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl`, you can run + +``` +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +## Example 3: Integrating with AWS S3 + +To integrate with cloud blob storage, we recommend using presigned urls. + +[Learn more about S3 presigned urls here] + +### Additional prerequisites + +* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). +* The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3. + - [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html). +* The `boto3` python package (Run `pip install boto3`) to generate presigned urls. + +### Step 1: Upload your input script + +To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + +Now upload your batch file to your S3 bucket. + +``` +aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +``` + + +### Step 2: Generate your presigned urls + +Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. + +(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py) + +``` +import boto3 +from botocore.exceptions import ClientError + +def generate_presigned_url(s3_client, client_method, method_parameters, expires_in): + """ + Generate a presigned Amazon S3 URL that can be used to perform an action. + + :param s3_client: A Boto3 Amazon S3 client. + :param client_method: The name of the client method that the URL performs. + :param method_parameters: The parameters of the specified client method. + :param expires_in: The number of seconds the presigned URL is valid for. + :return: The presigned URL. + """ + try: + url = s3_client.generate_presigned_url( + ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in + ) + except ClientError: + raise + return url + + +s3_client = boto3.client("s3") +input_url = generate_presigned_url( + s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 +) +output_url = generate_presigned_url( + s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 +) +print(f"{input_url=}") +print(f"{output_url=}") +``` + +This script should output + +``` +input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +``` + +### Step 3: Run the batch runner using your presigned urls + +You can now run the batch runner, using the urls generated in the previous section. + +``` +python -m vllm.entrypoints.openai.run_batch \ + -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + --model --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 4: View your results + +Your results are now on S3. You can view them in your terminal by running + +``` +aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - +``` diff --git a/examples/openi_example_batch.jsonl b/examples/openi_example_batch.jsonl new file mode 100644 index 0000000000000..5aa7e185c180a --- /dev/null +++ b/examples/openi_example_batch.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} diff --git a/requirements-common.txt b/requirements-common.txt index bd779d5acb68e..cc4b15d877d0f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,6 +8,7 @@ py-cpuinfo transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. tokenizers >= 0.19.1 # Required for Llama 3. fastapi +aiohttp openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. diff --git a/tests/entrypoints/test_openai_run_batch.py b/tests/entrypoints/test_openai_run_batch.py new file mode 100644 index 0000000000000..5de28513ca391 --- /dev/null +++ b/tests/entrypoints/test_openai_run_batch.py @@ -0,0 +1,53 @@ +import subprocess +import sys +import tempfile + +from vllm.entrypoints.openai.protocol import BatchRequestOutput + +# ruff: noqa: E501 +INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + +INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + + +def test_e2e(): + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode == 0, f"{proc=}" + + contents = output_file.read() + for line in contents.strip().split("\n"): + # Ensure that the output format conforms to the openai api. + # Validation should throw if the schema is wrong. + BatchRequestOutput.model_validate_json(line) + + +def test_e2e_invalid_input(): + """ + Ensure that we fail when the input doesn't conform to the openai api. + """ + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INVALID_INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode != 0, f"{proc=}" diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 35dfa09ac12ba..41e2f77fe56f1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -526,3 +526,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): model: str choices: List[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameteters of the request. + body: Union[ChatCompletionRequest, ] + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[ChatCompletionResponse] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000000000..99f1b2d6d091b --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,141 @@ +import argparse +import asyncio +import sys +from io import StringIO + +import aiohttp + +import vllm +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + ChatCompletionResponse) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, "r") as f: + return f.read() + + +async def write_file(path_or_url: str, data: str) -> None: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.put(path_or_url, data=data.encode("utf-8")): + pass + else: + # We should make this async, but as long as this is always run as a + # standalone program, blocking the event loop won't effect performance + # in this particular case. + with open(path_or_url, "w") as f: + f.write(data) + + +async def run_request(chat_serving: OpenAIServingChat, + request: BatchRequestInput) -> BatchRequestOutput: + chat_request = request.body + chat_response = await chat_serving.create_chat_completion(chat_request) + if isinstance(chat_response, ChatCompletionResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=chat_response, + error=None, + ) + else: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=None, + error=chat_response, + ) + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + + # When using single vLLM without engine_use_ray + model_config = await engine.get_model_config() + + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + ) + + # Submit all requests in the file to the engine "concurrently". + response_futures = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + request = BatchRequestInput.model_validate_json(request_json) + response_futures.append(run_request(openai_serving_chat, request)) + + responses = await asyncio.gather(*response_futures) + + output_buffer = StringIO() + for response in responses: + print(response.model_dump_json(), file=output_buffer) + + output_buffer.seek(0) + await write_file(args.output_file, output_buffer.read().strip()) + + # Temporary workaround for https://github.com/vllm-project/vllm/issues/4789 + sys.exit(0) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM API server version %s", vllm.__version__) + logger.info("args: %s", args) + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 65824a2206be9..c86e41c601be0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -119,7 +119,9 @@ def _parse_chat_message_content( return self._parse_chat_message_content_parts(role, content) async def create_chat_completion( - self, request: ChatCompletionRequest, raw_request: Request + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]: """Completion API similar to OpenAI's API. @@ -337,7 +339,7 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Request, + self, request: ChatCompletionRequest, raw_request: Optional[Request], result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage] ) -> Union[ErrorResponse, ChatCompletionResponse]: @@ -347,7 +349,7 @@ async def chat_completion_full_generator( final_res: Optional[RequestOutput] = None async for res in result_generator: - if await raw_request.is_disconnected(): + if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(request_id) return self.create_error_response("Client disconnected") From 30e754390c2a8a7198f472386d35ee1ec9443e4a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 16 May 2024 01:11:54 -0400 Subject: [PATCH 053/124] [Core] Implement sharded state loader (#4690) Co-authored-by: Woosuk Kwon --- examples/save_sharded_state.py | 75 +++++++++++ tests/test_sharded_state_loader.py | 90 +++++++++++++ vllm/config.py | 1 + vllm/executor/distributed_gpu_executor.py | 11 ++ vllm/model_executor/model_loader/loader.py | 148 +++++++++++++++++++++ vllm/worker/model_runner.py | 14 ++ vllm/worker/worker.py | 12 ++ 7 files changed, 351 insertions(+) create mode 100644 examples/save_sharded_state.py create mode 100644 tests/test_sharded_state_loader.py diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py new file mode 100644 index 0000000000000..c595d98ba2750 --- /dev/null +++ b/examples/save_sharded_state.py @@ -0,0 +1,75 @@ +""" +Saves each worker's model state dict directly to a checkpoint, which enables a +fast load path for large tensor-parallel models where each worker only needs to +read its own shard rather than the entire checkpoint. + +Example usage: + +python save_sharded_state.py \ + --model /path/to/load \ + --quantization deepspeedfp \ + --tensor-parallel-size 8 \ + --output /path/to/save + +Then, the model can be loaded with + +llm = LLM( + model="/path/to/save", + load_format="sharded_state", + quantization="deepspeedfp", + tensor_parallel_size=8, +) +""" +import argparse +import dataclasses +import os +import shutil +from pathlib import Path + +from vllm import LLM, EngineArgs + +parser = argparse.ArgumentParser() +EngineArgs.add_cli_args(parser) +parser.add_argument("--output", + "-o", + required=True, + type=str, + help="path to output checkpoint") +parser.add_argument("--file-pattern", + type=str, + help="string pattern of saved filenames") +parser.add_argument("--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file") + + +def main(args): + engine_args = EngineArgs.from_cli_args(args) + if engine_args.enable_lora: + raise ValueError("Saving with enable_lora=True is not supported!") + model_path = engine_args.model + if not Path(model_path).is_dir(): + raise ValueError("model path must be a local directory") + # Create LLM instance from arguments + llm = LLM(**dataclasses.asdict(engine_args)) + # Prepare output directory + Path(args.output).mkdir(exist_ok=True) + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=args.output, + pattern=args.file_pattern, + max_size=args.max_file_size) + # Copy metadata files to output directory + for file in os.listdir(model_path): + if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): + if os.path.isdir(os.path.join(model_path, file)): + shutil.copytree(os.path.join(model_path, file), + os.path.join(args.output, file)) + else: + shutil.copy(os.path.join(model_path, file), args.output) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py new file mode 100644 index 0000000000000..8540e98da366a --- /dev/null +++ b/tests/test_sharded_state_loader.py @@ -0,0 +1,90 @@ +import os +import shutil +from tempfile import TemporaryDirectory + +import pytest +import torch +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.model_executor.model_loader.loader import ShardedStateLoader + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + seed=0, + max_tokens=256, + ignore_eos=True, +) + + +def test_filter_subtensors(): + state_dict = { + "a": torch.empty(2), + "b": torch.empty((2, 4)), + "c": torch.empty((2, 4, 8)), + } + state_dict.update({ + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + }) + filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) + assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") + for key, tensor in filtered_state_dict.items(): + assert tensor.equal(state_dict[key]) + + +@pytest.mark.parametrize("enable_lora", [False, True]) +def test_sharded_state_loader(enable_lora): + weights_patterns = ("*.bin", "*.pt", "*.safetensors") + + with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: + input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", + cache_dir=cache_dir) + + llm = LLM( + model=input_dir, + worker_use_ray=True, + gpu_memory_utilization=0.3, + ) + + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir) + del llm.llm_engine.model_executor + + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + del llm_before.llm_engine.model_executor + + llm_after = LLM( + model=output_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + load_format="sharded_state", + ) + gen_after = llm_after.generate(prompts, sampling_params) + out_after = [gen.outputs[0].__dict__ for gen in gen_after] + del llm_after.llm_engine.model_executor + + assert out_before == out_after diff --git a/vllm/config.py b/vllm/config.py index 2eb5bdd18d812..91f590aaf79eb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum): NPCACHE = "npcache" DUMMY = "dummy" TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" @dataclass diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4c922ef63ee04..c5b1e61112afb 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -77,6 +77,17 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self._run_workers("list_loras") + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self._run_workers("save_sharded_state", + path=path, + pattern=pattern, + max_size=max_size) + @abstractmethod def _run_workers( self, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b14824a359b6d..dc568928b2859 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,4 +1,5 @@ # ruff: noqa: SIM117 +import collections import copy import glob import os @@ -366,6 +367,150 @@ def load_model(self, *, model_config: ModelConfig, cache_config) +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_states.py` for creating a sharded checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + from safetensors.torch import safe_open + + from vllm.distributed import get_tensor_model_parallel_rank + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + model_config.model, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", tensor.shape, + key, param_shape) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.TENSORIZER: return TensorizerLoader(load_config) + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dcdd4b962454e..623a0bc32211c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -182,6 +182,20 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82cf58101a95b..faea50fbfbf50 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -119,6 +119,18 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many From 973617ae02a4e8e6190674cf1cdb0c0803b65ae6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 16 May 2024 00:53:51 -0700 Subject: [PATCH 054/124] [Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840) Co-authored-by: Cade Daniel Co-authored-by: Cade Daniel --- .buildkite/test-pipeline.yaml | 1 + benchmarks/benchmark_latency.py | 6 + tests/spec_decode/e2e/test_compatibility.py | 50 ----- tests/spec_decode/e2e/test_integration.py | 44 +++++ .../spec_decode/e2e/test_integration_dist.py | 65 +++++++ .../e2e/test_multistep_correctness.py | 37 ---- vllm/distributed/communication_op.py | 10 +- vllm/executor/gpu_executor.py | 71 ++------ vllm/executor/ray_gpu_executor.py | 19 +- vllm/spec_decode/spec_decode_worker.py | 171 +++++++++++++++--- vllm/worker/worker.py | 3 +- vllm/worker/worker_base.py | 2 +- 12 files changed, 297 insertions(+), 182 deletions(-) create mode 100644 tests/spec_decode/e2e/test_integration.py create mode 100644 tests/spec_decode/e2e/test_integration_dist.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index eeb6eaa2165bc..aa74672f4bf67 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -42,6 +42,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s spec_decode/e2e/test_integration_dist.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 44da3bad8d840..8f3168c115ae6 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -18,6 +18,8 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, tokenizer=args.tokenizer, quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, @@ -28,6 +30,7 @@ def main(args: argparse.Namespace): quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, + use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size) @@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None): description='Benchmark the latency of processing a single batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--speculative-model', type=str, default=None) + parser.add_argument('--num-speculative-tokens', type=int, default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', @@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 60c20ed7db7a3..81f91c5e10b0d 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -5,56 +5,6 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Expect failure as spec decode not supported by - # Ray backend. - "worker_use_ray": True, - }, - ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_ray(test_llm_generator): - """Verify that speculative decoding with Ray fails. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - try: - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) - finally: - # we need to free up ray resource, - # so that latter test could use the gpu we allocated here - import ray - ray.shutdown() - - @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py new file mode 100644 index 0000000000000..4a2b62151f8cd --- /dev/null +++ b/tests/spec_decode/e2e/test_integration.py @@ -0,0 +1,44 @@ +"""Tests which cover integration of the speculative decoding framework with +other features, e.g. cuda graphs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) diff --git a/tests/spec_decode/e2e/test_integration_dist.py b/tests/spec_decode/e2e/test_integration_dist.py new file mode 100644 index 0000000000000..d444ef24cbfda --- /dev/null +++ b/tests/spec_decode/e2e/test_integration_dist.py @@ -0,0 +1,65 @@ +"""Tests which cover integration of the speculative decoding framework with +tensor parallelism. +""" + +import pytest +import torch + +from vllm.utils import is_hip + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 2, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when tensor parallelism is used. + """ + if is_hip(): + pytest.skip("hip is not well-supported yet") + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index d2da039e84c07..94d71fb012727 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Required for spec decode. - "use_v2_block_manager": True, - - # Verify equality when cuda graphs allowed. - "enforce_eager": False, - "model": "JackFram/llama-68m", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Identical models. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("output_len", [32]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, - batch_size, output_len): - """Verify spec decode equality when cuda graphs are enabled. - """ - run_greedy_equality_correctness_test( - baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True, - ) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 9cc776f8324f2..f8ee0f9796bcd 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -219,16 +219,16 @@ def broadcast_tensor_dict( to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, dtypes). """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() + or torch.distributed.get_world_size(group=group) == 1): + return tensor_dict + group = group or torch.distributed.group.WORLD metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return tensor_dict - rank = torch.distributed.get_rank() if rank == src: metadata_list: List[Tuple[Any, Any]] = [] diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 2b72b31b5f070..3ad201f4757ec 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: """Initialize the worker and load the model. - - If speculative decoding is enabled, we instead create the speculative - worker. """ - if self.speculative_config is None: - self._init_non_spec_worker() - else: - self._init_spec_worker() + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() def _get_worker_kwargs( self, @@ -45,6 +44,7 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + speculative_config=self.speculative_config, is_driver_worker=rank == 0, ) @@ -52,59 +52,22 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): + + if self.speculative_config is None: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + else: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, ) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker - def _init_non_spec_worker(self): - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _init_spec_worker(self): - """Initialize a SpecDecodeWorker, using a draft model for proposals. - """ - assert self.speculative_config is not None - - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - - target_worker = self._create_worker() - - draft_worker_kwargs = self._get_worker_kwargs() - # Override draft-model specific worker args. - draft_worker_kwargs.update( - model_config=self.speculative_config.draft_model_config, - parallel_config=self.speculative_config.draft_parallel_config, - ngram_prompt_lookup_max=self.speculative_config. - ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.speculative_config. - ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. - #load_config=self.load_config, - ) - - spec_decode_worker = SpecDecodeWorker.create_worker( - scorer_worker=target_worker, - draft_worker_kwargs=draft_worker_kwargs, - disable_by_batch_size=self.speculative_config. - speculative_disable_by_batch_size, - ) - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = spec_decode_worker - - # Load model handled in spec decode worker. - self.driver_worker.init_device() - def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 9cb03ec8c3f5a..dd3ee60682d30 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -28,9 +28,6 @@ class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -90,14 +87,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_id, ) + + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerWrapper).remote( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) @@ -107,8 +112,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) else: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a4e759095b294..ef17b8c1e2cc0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,6 +3,7 @@ import torch +from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (ExecuteModelRequest, SamplerOutput, @@ -17,11 +18,43 @@ get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) +from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. + """ + assert "speculative_config" in kwargs + speculative_config = kwargs.get("speculative_config") + assert speculative_config is not None + + target_worker = Worker(*args, **kwargs) + + draft_worker_kwargs = kwargs.copy() + # Override draft-model specific worker args. + draft_worker_kwargs.update( + model_config=speculative_config.draft_model_config, + parallel_config=speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + # TODO allow draft-model specific load config. + #load_config=load_config, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=speculative_config. + speculative_disable_by_batch_size, + ) + + return spec_decode_worker + + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -142,6 +175,9 @@ def init_device(self) -> None: self._configure_model_sampler_for_spec_decode() + def load_model(self, *args, **kwargs): + pass + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -195,39 +231,97 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def _broadcast_control_flow_decision( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + disable_all_speculation: bool = False) -> Tuple[int, bool]: + """Broadcast how many lookahead slots are scheduled for this step, and + whether all speculation is disabled, to all non-driver workers. + + This is required as if the number of draft model runs changes + dynamically, the non-driver workers won't know unless we perform a + communication to inform then. + + Returns the broadcasted num_lookahead_slots and disable_all_speculation. + """ + + if self.rank == self._driver_rank: + assert execute_model_req is not None + + broadcast_dict = dict( + num_lookahead_slots=execute_model_req.num_lookahead_slots, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + else: + assert execute_model_req is None + broadcast_dict = broadcast_tensor_dict(src=self._driver_rank) + + return (broadcast_dict["num_lookahead_slots"], + broadcast_dict["disable_all_speculation"]) + @torch.inference_mode() def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding " - "requires non-None seq_group_metadata_list") + disable_all_speculation = False + if self.rank == self._driver_rank: + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + + (num_lookahead_slots, + disable_all_speculation) = self._broadcast_control_flow_decision( + execute_model_req, disable_all_speculation) + + if self.rank == self._driver_rank: + assert execute_model_req is not None + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list" + ) + + self._maybe_disable_speculative_tokens( + disable_all_speculation, + execute_model_req.seq_group_metadata_list) + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + else: + self._run_non_driver_rank(num_lookahead_slots) + return [] + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding # to stop trading off throughput for latency. - disable_all = (execute_model_req.running_queue_size >= - self.disable_by_batch_size) - if disable_all: - for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # Once num_speculative_tokens is set to 0, the spec decode - # of this request will be disabled forever. - # TODO(comaniac): We currently store spec decoding specific - # state in the global data structure, but we should maintain - # this state within spec decode worker. - seq_group_metadata.num_speculative_tokens = 0 - - # If no spec tokens, call the proposer and scorer workers normally. - # This happens for prefill, or when the spec decode is disabled - # for this batch. - if execute_model_req.num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all) - - return self._run_speculative_decoding_step(execute_model_req) + disable_all_speculation = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + + return disable_all_speculation + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, @@ -252,10 +346,28 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.logprobs = None return [sampler_output] + def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: + """Run proposer and verifier model in non-driver workers. This is used + for both speculation cases (num_lookahead_slots>0) and non-speculation + cases (e.g. prefill). + """ + # In non-driver workers the input is None + execute_model_req = None + + # Even if num_lookahead_slots is zero, we want to run the proposer model + # as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we should + # delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model(execute_model_req) + + self.scorer_worker.execute_model(execute_model_req) + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest, + num_lookahead_slots: int) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -264,6 +376,7 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ + assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) @@ -455,6 +568,10 @@ def rank(self): def device(self): return self.scorer_worker.device + @property + def _driver_rank(self) -> int: + return 0 + def get_cache_block_size_bytes(self): """Return the size of a cache block in bytes. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index faea50fbfbf50..97b3873b2a9f6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, @@ -43,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fb32feaca0c94..1f04f821eb0f0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -121,7 +121,7 @@ def update_environment_variables(envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ Actual initialization of the worker class, and set up - function tracing if required. + function tracing if required. Arguments are passed to the worker class constructor. """ enable_trace_function_call_for_thread() From 5c342570d7e4c73bfaa4c4057174b92117d322bb Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Thu, 16 May 2024 09:36:49 -0400 Subject: [PATCH 055/124] Add marlin unit tests and marlin benchmark script (#4815) --- benchmarks/kernels/benchmark_marlin.py | 183 ++++++++++++++++++ benchmarks/kernels/benchmark_shapes.py | 75 +++++++ tests/kernels/test_marlin_gemm.py | 158 +++++++++++++++ .../layers/quantization/utils/__init__.py | 0 .../layers/quantization/utils/marlin_utils.py | 174 +++++++++++++++++ .../layers/quantization/utils/quant_utils.py | 146 ++++++++++++++ 6 files changed, 736 insertions(+) create mode 100644 benchmarks/kernels/benchmark_marlin.py create mode 100644 benchmarks/kernels/benchmark_shapes.py create mode 100644 tests/kernels/test_marlin_gemm.py create mode 100644 vllm/model_executor/layers/quantization/utils/__init__.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils.py create mode 100644 vllm/model_executor/layers/quantization/utils/quant_utils.py diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000..5dcffc284f3d4 --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,183 @@ +import argparse + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def bench_run(results, model, act_order, is_k_full, num_bits, group_size, + size_m, size_k, size_n): + label = "Quant Matmul" + + sub_label = ("{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, act_order) + + # GPTQ quant + (w_ref, q_w, s, g_idx, + rand_perm) = quantize_weights(b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n) + + globals = { + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "is_k_full": is_k_full, + "a": a, + "a_tmp": a_tmp, + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + "marlin_workspace": marlin_workspace, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results = [] + + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + for is_k_full in K_FULL_OPTS: + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + if len( + args.limit_group_size + ) > 0 and group_size not in args.limit_group_size: + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k + or group_size == -1): + continue + + for size_m in args.batch_sizes: + bench_run(results, model, act_order, is_k_full, + num_bits, group_size, size_m, size_k, + size_n) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000..4eeeca35a37cc --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,75 @@ +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py new file mode 100644 index 0000000000000..b0ad85c25c572 --- /dev/null +++ b/tests/kernels/test_marlin_gemm.py @@ -0,0 +1,158 @@ +"""Tests for the marlin kernel. + +Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +""" +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + +K_CHUNKS = [128, 256] +N_CHUNKS = [64, 128, 256] + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (1, 7 * 4, 5 * 1), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] + + +def rand_data(shape): + data = torch.rand(shape).to(torch.half).cuda() + return data + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, + mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Create input + b_weight = rand_data((size_k, size_n)) + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, + group_size, act_order) + + # Pack to GPTQ format + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Pack to Marlin format + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + + # Run Marlin repack GPU kernel + marlin_q_w_2 = ops.gptq_marlin_repack( + q_w_gptq, + sort_indices, + size_k, + size_n, + num_bits, + ) + torch.cuda.synchronize() + + assert torch.allclose(marlin_q_w_1, marlin_q_w_2) + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) +def test_marlin_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, + act_order, + is_k_full, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, num_bits, group_size, act_order) + + workspace = MarlinWorkspace(size_n) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_q_w, + marlin_s, + g_idx, + sort_indices, + workspace.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + ) + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + assert torch.allclose(output, output_ref, rtol=1e-2) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000..33b3169983475 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,174 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_pack_factor, quantize_weights, sort_weights) + +__cuda_arch = torch.cuda.get_device_capability() + + +def is_marlin_supported(): + return __cuda_arch[0] >= 8 + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def _get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm = {} +_scale_perm = {} +_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = _get_perms(num_bits) + _perm[num_bits] = perm + _scale_perm[num_bits] = scale_perm + _scale_perm_single[num_bits] = scale_perm_single + + +def marlin_permute_weights(q_w, + size_k, + size_n, + num_bits, + tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape( + (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(_scale_perm[num_bits])))[:, + _scale_perm[num_bits]] + else: + s = s.reshape( + (-1, + len(_scale_perm_single[num_bits])))[:, + _scale_perm_single[num_bits]] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +class MarlinWorkspace: + + def __init__(self, out_features): + assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( + "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" + .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) + + max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * + GPTQ_MARLIN_MAX_PARALLEL) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000..177cb23f63cf4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,146 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +SUPPORTED_NUM_BITS = [4, 8] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def get_pack_factor(num_bits): + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res From 99caa4910651754f3f68de518ca42349c8c424d1 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 16 May 2024 21:55:29 +0800 Subject: [PATCH 056/124] [Kernel] add bfloat16 support for gptq marlin kernel (#4788) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 240 +++++++++++++----- .../gptq_marlin/gptq_marlin_dtypes.cuh | 62 +++++ tests/models/test_gptq_marlin.py | 9 +- .../layers/quantization/gptq_marlin.py | 8 +- 4 files changed, 246 insertions(+), 73 deletions(-) create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9c6bff000e916..fdc0ebef4672e 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -20,6 +20,11 @@ */ #include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } @@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template ; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { +template +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { +template +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) { const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) { return frag_b; } -__device__ inline FragB dequant_8bit(int q) { +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16 +// Reference: +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), @@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) { return frag_b; } +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t * fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +template +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, - FragS &frag_s_3, FragS &frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i]; +template +__device__ inline void scale4(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } // Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float *c, FragS &s) { - __half *s_ptr = reinterpret_cast<__half *>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +template +__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. @@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; constexpr int pack_factor = 32 / num_bits; @@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll @@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); } else { int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); } // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); + scale(frag_b0, frag_s[k % 2][j], 0); } } // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); + scale(frag_b1, frag_s[k % 2][j], 1); } } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; @@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int j = 0; j < 2 * 4; j++) { reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS &s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((half2 *)sh)[idx] = res; + ((scalar_t2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } @@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ @@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +template void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, void *g_idx, void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, void *workspace, int num_bits, @@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } return c; } -#endif +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh new file mode 100644 index 0000000000000..7881abbe4cbbf --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -0,0 +1,62 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + + +namespace gptq_marlin { + +template +class ScalarType { +}; + +template <> +class ScalarType { +public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { return __half2float(x); } + + static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } + + static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } +}; + +template <> +class ScalarType { +public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } +#endif +}; + +} + +#endif diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index db55d4488a374..1fc0b3f239127 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -14,6 +14,7 @@ import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT from .utils import check_logprobs_close @@ -52,7 +53,7 @@ @pytest.mark.skipif(gptq_marlin_not_supported, reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( @@ -76,11 +77,15 @@ def test_models( gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) del gptq_marlin_model + _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. + # The naive gptq kernel doesn't support bf16 yet. + # Here we always compare fp16/bf16 gpt marlin kernel + # to fp16 gptq kernel. gptq_model = vllm_runner(model_name=model_name, revision=revision, - dtype=dtype, + dtype="half", quantization="gptq", max_model_len=MAX_MODEL_LEN, tensor_parallel_size=1) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e2464008a875f..354bb55d09e24 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -99,7 +99,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] + return [torch.half, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: @@ -186,9 +186,9 @@ def create_weights( group_size = input_size # Validate dtype - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") + if params_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError(f"The params dtype must be float16 " + f"or bfloat16, but got {params_dtype}") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) From dbc0754ddfc33e60454ec4a9cdba945f172a39ef Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 16 May 2024 11:42:17 -0400 Subject: [PATCH 057/124] [docs] Fix typo in examples filename openi -> openai (#4864) --- .../{openi_example_batch.jsonl => openai_example_batch.jsonl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{openi_example_batch.jsonl => openai_example_batch.jsonl} (100%) diff --git a/examples/openi_example_batch.jsonl b/examples/openai_example_batch.jsonl similarity index 100% rename from examples/openi_example_batch.jsonl rename to examples/openai_example_batch.jsonl From 5e0391c0406ec225b2c58bc22d5be864a432fe40 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 16 May 2024 11:42:41 -0400 Subject: [PATCH 058/124] [Frontend] Separate OpenAI Batch Runner usage from API Server (#4851) --- vllm/entrypoints/openai/run_batch.py | 2 +- vllm/usage/usage_lib.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 99f1b2d6d091b..731f4f4a4028a 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -101,7 +101,7 @@ async def main(args): engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 9029a5b16af72..40a954a29493e 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -90,6 +90,7 @@ class UsageContext(str, Enum): LLM_CLASS = "LLM_CLASS" API_SERVER = "API_SERVER" OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" ENGINE_CONTEXT = "ENGINE_CONTEXT" From 9216b9cc38e8753de442877863ad651425902e1f Mon Sep 17 00:00:00 2001 From: Pierre Dulac Date: Thu, 16 May 2024 18:42:21 +0200 Subject: [PATCH 059/124] [Bugfix] Bypass authorization API token for preflight requests (#4862) --- vllm/entrypoints/openai/api_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7cd51b959a0ea..97b35262329ee 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -154,6 +154,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @app.middleware("http") async def authentication(request: Request, call_next): root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: From 6979ade3840703d055402e78168194644e3cbdd8 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 16 May 2024 12:56:15 -0400 Subject: [PATCH 060/124] Add GPTQ Marlin 2:4 sparse structured support (#4790) Co-authored-by: Robert Shaw --- CMakeLists.txt | 3 +- csrc/ops.h | 11 + csrc/pybind.cpp | 3 +- csrc/quantization/marlin/{ => dense}/LICENSE | 0 .../marlin/{ => dense}/marlin_cuda_kernel.cu | 0 csrc/quantization/marlin/sparse/LICENSE | 203 +++ csrc/quantization/marlin/sparse/common/base.h | 49 + csrc/quantization/marlin/sparse/common/mem.h | 132 ++ csrc/quantization/marlin/sparse/common/mma.h | 175 +++ .../marlin/sparse/marlin_24_cuda_kernel.cu | 1110 +++++++++++++++++ tests/models/test_gptq_marlin_24.py | 81 ++ vllm/_custom_ops.py | 10 + vllm/config.py | 48 +- .../layers/quantization/__init__.py | 11 +- .../layers/quantization/base_config.py | 11 + .../layers/quantization/gptq_marlin.py | 23 + .../layers/quantization/gptq_marlin_24.py | 280 +++++ .../layers/quantization/marlin.py | 22 + 18 files changed, 2131 insertions(+), 41 deletions(-) rename csrc/quantization/marlin/{ => dense}/LICENSE (100%) rename csrc/quantization/marlin/{ => dense}/marlin_cuda_kernel.cu (100%) create mode 100644 csrc/quantization/marlin/sparse/LICENSE create mode 100644 csrc/quantization/marlin/sparse/common/base.h create mode 100644 csrc/quantization/marlin/sparse/common/mem.h create mode 100644 csrc/quantization/marlin/sparse/common/mma.h create mode 100644 csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu create mode 100644 tests/models/test_gptq_marlin_24.py create mode 100644 vllm/model_executor/layers/quantization/gptq_marlin_24.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c7dfe0c048b0..2051d7560be25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,7 +176,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu") diff --git a/csrc/ops.h b/csrc/ops.h index 9541adcb3de88..ef37131c962f8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -125,6 +125,17 @@ torch::Tensor marlin_gemm( int64_t size_n, int64_t size_k); +torch::Tensor gptq_marlin_24_gemm( + torch::Tensor &a, + torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, + int64_t num_bits, + int64_t size_m, + int64_t size_n, + int64_t size_k); + torch::Tensor gptq_marlin_gemm( torch::Tensor &a, torch::Tensor &b_q_weight, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 173e0b1732e13..0339eba70c013 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -66,7 +66,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); diff --git a/csrc/quantization/marlin/LICENSE b/csrc/quantization/marlin/dense/LICENSE similarity index 100% rename from csrc/quantization/marlin/LICENSE rename to csrc/quantization/marlin/dense/LICENSE diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu similarity index 100% rename from csrc/quantization/marlin/marlin_cuda_kernel.cu rename to csrc/quantization/marlin/dense/marlin_cuda_kernel.cu diff --git a/csrc/quantization/marlin/sparse/LICENSE b/csrc/quantization/marlin/sparse/LICENSE new file mode 100644 index 0000000000000..ca75fb15e660a --- /dev/null +++ b/csrc/quantization/marlin/sparse/LICENSE @@ -0,0 +1,203 @@ +Contains code from https://github.com/IST-DASLab/Sparse-Marlin/ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h new file mode 100644 index 0000000000000..929b39d7642f1 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace marlin_24 { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +template struct ShapeBase { + static constexpr int M = M_, N = N_, K = K_; +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragM = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h new file mode 100644 index 0000000000000..a49d15ca544eb --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred_zfill(void *smem_ptr, + const void *glob_ptr, + bool pred = true, + const bool zfill = false) { + const int BYTES = 16; + int src_in_bytes = (zfill ? 0 : BYTES); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template __device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_m); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h new file mode 100644 index 0000000000000..9319456677d36 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { + +// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, + const FragA &frag_b, FragC &frag_c, FragM &frag_m, + const int psel) { + const uint32_t *a0 = reinterpret_cast(&a_frag0); + const uint32_t *a1 = reinterpret_cast(&a_frag1); + const uint32_t *b = reinterpret_cast(&frag_b); + const uint32_t *e = reinterpret_cast(&frag_m); + float *c = reinterpret_cast(&frag_c); + if (psel == 0) { + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } else { + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template __device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, + float c3) { + uint2 r; + asm("{\n\t" + ".reg .f16 a, b, c, d; \n\t" + "cvt.rn.f16.f32 a, %2; \n\t" + "cvt.rn.f16.f32 b, %3; \n\t" + "cvt.rn.f16.f32 c, %4; \n\t" + "cvt.rn.f16.f32 d, %5; \n\t" + "mov.b32 %0, {a, b}; \n\t" + "mov.b32 %1, {c, d}; \n\t" + "}" + : "=r"(r.x), "=r"(r.y) + : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + return r; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, + FragS &s0, float *c4, float *c5, float *c6, + float *c7, FragS &s1) { + *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); + *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); + *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); + *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); + + *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); + *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); + *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); + *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); +} + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu new file mode 100644 index 0000000000000..42b0566183a8d --- /dev/null +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -0,0 +1,1110 @@ +/* + * Notice: This file was modified by Neuralmagic inc to include 8-bit support + * + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include "common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +#else + +#include "common/mem.h" +#include "common/mma.h" + +#endif + +template inline std::string str(T x) { return std::to_string(x); } + +namespace marlin_24 { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int THREADS = 256; +static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 128; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void Marlin_24( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) {} + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void Marlin_24( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + // number of thread_k_blocks in k-dim + int k_tiles = prob_k / 32 / thread_k_blocks; + // number of thread_n_blocks in n-dim + int n_tiles = prob_n / 16 / thread_n_blocks; + // iters needed to cover all slices + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + // number of threadblock tiles in the current slice + int slice_iters; + // total number of active threadblocks in the current slice + int slice_count = 0; + // index of threadblock in current slice; numbered bottom to top + int slice_idx; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 32 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + constexpr int pack_factor = 32 / num_bits; + + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4 *meta_ptr[m_sh_iters]; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2][b_thread_vecs]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], + meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + + // Load meta with ldsm4 + int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + +// Parallel logarithmic shared memory reduction. We make sure to avoid any +// unnecessary read or write iterations, e.g., for two warps we write only once +// by warp 1 and read only once by warp 0. +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped partitioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, + float c4, float c5, float c6, float c7, FragS &s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2 *tmp = (half2 *)&res; + // for per-column quantization we finally apply the scale here + if constexpr (group_blocks == -1 && num_bits == 4) { + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4 *)sh)[idx] = *((int4 *)&res[0]); + }; + + // RLC: only warp 0 and 1 baseline example + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], + frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], + frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], + frag_c[i][3][0][3], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], + frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], + frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], + frag_c[i][3][1][2], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], + frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], + frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], + frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); + + c_sh_wr += 8 * c_sh_stride_2; + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { +#pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { +// We unroll over both the global fetch and the register load pipeline to ensure +// all shared memory accesses are static. Note that both pipelines have even +// length meaning that the next iteration will always start at index 0. +#pragma unroll + for (int pipe = 0; pipe < stages;) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + matmul(pipe); + + pipe++; + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + } + } + thread_block_reduce(); + + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], + &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], + &frag_c[i][0][0][2], &frag_c[i][1][0][2], + &frag_c[i][2][0][2], &frag_c[i][3][0][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], + &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], + &frag_c[i][0][0][3], &frag_c[i][1][0][3], + &frag_c[i][2][0][3], &frag_c[i][3][0][3], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], + &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], + &frag_c[i][0][1][2], &frag_c[i][1][1][2], + &frag_c[i][2][1][2], &frag_c[i][3][1][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], + &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], + &frag_c[i][0][1][3], &frag_c[i][1][1][3], + &frag_c[i][2][1][3], &frag_c[i][3][1][3], + frag_s[0][2]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#endif + +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ + } + +void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, + void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int num_bits, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + TORCH_CHECK(sms > 0); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (thread_k == -1 || thread_m == -1) { + if (prob_n <= 16) { + // For small batchizes, better partitioningif is slightly more important than + // better compute utilization + thread_k = 128; + thread_m = 128; + } else { + thread_k = 64; + thread_m = 256; + } + } + + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, + " is not divisible by thread_m = ", thread_m); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, + " is not divisible by group_blocks = ", group_blocks); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + const int4 *meta_ptr = (const int4 *)meta; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + + int *locks = (int *)workspace; + for (int i = 0; i < tot_n_blocks; i += 4) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + if (thread_n_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_n = 64 * par; + i += 4 * (par - 1); + thread_n_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + + // the false is start of the CALL_IF macros + if (false) { + } // BMxBNxBK, group + // 4-bit + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 2, 2, 4) + CALL_IF_2_4(4, 16, 3, 2, -1) + CALL_IF_2_4(4, 16, 3, 2, 4) + CALL_IF_2_4(4, 16, 4, 2, -1) + CALL_IF_2_4(4, 16, 4, 2, 4) + + // 8-bit + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 2, 2, 4) + CALL_IF_2_4(8, 16, 3, 2, -1) + CALL_IF_2_4(8, 16, 3, 2, 4) + CALL_IF_2_4(8, 16, 4, 2, -1) + CALL_IF_2_4(8, 16, 4, 2, 4) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; + } +} + +} // namespace marlin_24 + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin_24::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_24::tile_size)); + TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin_24::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_24::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_24::tile_size)); + + int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify meta + TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, + "b_meta.size(0) = ", b_meta.size(0), + " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); + TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), + " is not size_n * 2 = ", size_n * 2); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify b_meta device and strides + TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); + TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + int thread_k = -1; + int thread_m = -1; + int sms = -1; + int max_par = 16; + + int groupsize = -1; + if (b_scales.size(0) > 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + groupsize = size_k / b_scales.size(0); + groupsize /= 2; // Because of 24 + } + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 64, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_24::min_thread_n)); + int min_workspace_size = + (size_n / marlin_24::min_thread_n) * marlin_24::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin_24::marlin_cuda_2_4( + a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), + num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_m, sms, max_par); + + return c; +} diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py new file mode 100644 index 0000000000000..3e6ffb7f90fcc --- /dev/null +++ b/tests/models/test_gptq_marlin_24.py @@ -0,0 +1,81 @@ +"""Compare the outputs of a GPTQ model to a Marlin_24 model. + +Note: GPTQ and Marlin_24 do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Run `pytest tests/models/test_marlin_24.py`. +""" +from dataclasses import dataclass + +import pytest +import torch + +from tests.models.utils import check_logprobs_close +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +marlin_not_supported = (capability < + QUANTIZATION_METHODS["marlin"].get_min_capability()) + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +model_pairs = [ + # 4-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"), + # 4-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), + + # 8-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"), + # 8-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(marlin_not_supported, + reason="Marlin24 is not supported on this GPU type.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + marlin_24_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="gptq_marlin_24") + marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del marlin_24_model + + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_24_outputs, + name_0="gptq", + name_1="marlin_24", + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 42dedfdf76c4f..95baa84262658 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,6 +153,16 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, num_bits: int, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, num_bits, size_m, size_n, + size_k) + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 91f590aaf79eb..77ce8c318d8f1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,14 +7,11 @@ from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, - get_quantization_config) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron -GPTQMarlinConfig = get_quantization_config("gptq_marlin") - if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -155,37 +152,15 @@ def _verify_quantization(self) -> None: quant_cfg = getattr(self.hf_config, "quantization_config", None) if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" - or quant_cfg.get("is_marlin_format", False)) - - # Check which LinearMethod the GPTQ model should use. - if quant_method == "gptq": - # If serialized in Marlin format, use MarlinLinearMethod. - # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod. - if is_format_marlin: - logger.info("The model is serialized in Marlin format. " - "Using Marlin kernel.") - quant_method = "marlin" - if self.quantization == "gptq": - self.quantization = quant_method - - # If convertible to Marlin format, use GPTQMarlinLinearMethod - # unless the user explicitly specified GPTQLinearMethod. - elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg): - if self.quantization == "gptq": - logger.warning( - "The model is convertible to Marlin format, but " - "you specified quantization=gptq. Use " - "quantization=marlin for faster inference.") - else: - logger.info( - "The model is convertible to Marlin format. " - "Using Marlin kernel.") - quant_method = "gptq_marlin" - if self.quantization == "marlin": - self.quantization = quant_method + + # Detect which checkpoint is it + for name, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break # Verify quantization configurations. if self.quantization is None: @@ -207,7 +182,8 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if (self.quantization not in ["marlin", "gptq_marlin"]): + if (self.quantization + not in ["marlin", "gptq_marlin_24", "gptq_marlin"]): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 5798bc359dcf2..f938e7d37ec5f 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -10,18 +10,23 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, "fp8": Fp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, - "gptq_marlin": GPTQMarlinConfig, - "marlin": MarlinConfig, - "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index ff5cf0b2bd61a..e7de283b562a6 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -66,6 +66,17 @@ def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" raise NotImplementedError + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + @staticmethod def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: """Get a value from the model's quantization config.""" diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 354bb55d09e24..4374fd98012f6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -6,11 +6,14 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +logger = init_logger(__name__) + GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 @@ -117,6 +120,26 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": is_sym = cls.get_from_keys(config, ["sym"]) return cls(weight_bits, group_size, desc_act, is_sym) + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference") + return None + def get_quant_method( self, layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000000000..1bd6127104654 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,280 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + if self.weight_bits != 4 and self.weight_bits != 8: + raise ValueError("weight_bits must be 4 or 8. Got = {}".format( + self.weight_bits)) + + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin24, but got group_size of " + f"{self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 128 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + "marlin_tile_size": self.quant_config.tile_size, + }, + ) + + # Meta + meta = Parameter( + torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + requires_grad=False, + ) + set_weight_attrs( + meta, + { + "input_dim": 0, + "packed_dim": 1, + "pack_factor": 1, + "output_dim": 1, + "marlin_tile_size": 2, + }, + ) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + scales = Parameter( + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": None if input_groups == 1 else 0, + "output_dim": 1, + }, + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) + + layer.register_parameter("B_24", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("B_meta", meta) + set_weight_attrs(meta, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.weight_bits, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 94aba620ea083..3613c9d9ecf2a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -4,11 +4,14 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +logger = init_logger(__name__) + class MarlinConfig(QuantizationConfig): """Config class for Marlin. @@ -72,6 +75,25 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(group_size) + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + def get_quant_method( self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: if isinstance(layer, LinearBase): From f09edd8a25d54c48eb804abe391e98d0b85b9ea2 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 16 May 2024 10:02:56 -0700 Subject: [PATCH 061/124] Add JSON output support for benchmark_latency and benchmark_throughput (#4848) --- .buildkite/run-benchmarks.sh | 7 ++++--- benchmarks/benchmark_latency.py | 20 ++++++++++++++++++-- benchmarks/benchmark_throughput.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 7fbad1c4bd950..1efc96395933f 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." (which wget && which curl) || (apt-get update && apt-get install -y wget curl) # run python-based benchmarks and upload the result to buildkite -python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt +python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? -python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite @@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi -/workspace/buildkite-agent artifact upload openai-*.json +rm ShareGPT_V3_unfiltered_cleaned_split.json +/workspace/buildkite-agent artifact upload "*.json" diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8f3168c115ae6..f84e3453947c9 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,5 +1,6 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import json import time from pathlib import Path from typing import Optional @@ -96,6 +97,16 @@ def run_to_completion(profile_dir: Optional[str] = None): for percentage, percentile in zip(percentages, percentiles): print(f'{percentage}% percentile latency: {percentile} seconds') + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -149,8 +160,8 @@ def run_to_completion(profile_dir: Optional[str] = None): help= 'Data type for kv cache storage. If "auto", will use model data type. ' 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( '--quantization-param-path', type=str, @@ -197,5 +208,10 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 695d06e7b243d..41f443968c3c4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -242,6 +242,18 @@ def main(args: argparse.Namespace): print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the throughput.") @@ -353,6 +365,11 @@ def main(args: argparse.Namespace): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model From b5853f99639afd82cb18f131dce7f8c41eda74bd Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Thu, 16 May 2024 13:46:52 -0400 Subject: [PATCH 062/124] [ROCm][AMD][Bugfix] adding a missing triton autotune config (#4845) --- vllm/attention/ops/triton_flash_attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 1147664183ff1..f94211116a746 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -239,6 +239,16 @@ def _attn_fwd_inner( num_stages=1, num_warps=8, ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), triton.Config( { "BLOCK_M": 128, From e08188081be890f72a8d3dda66e7e1ce0a45216c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 16 May 2024 10:59:52 -0700 Subject: [PATCH 063/124] [Core][Distributed] remove graph mode function (#4818) --- tests/distributed/test_custom_all_reduce.py | 5 +- tests/distributed/test_pynccl.py | 4 +- vllm/distributed/communication_op.py | 70 ++++++++++++--------- vllm/worker/model_runner.py | 38 ++++++----- 4 files changed, 63 insertions(+), 54 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 630b25a3b6132..186f9faa6bfb6 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with graph_capture(): + with graph_capture() as graph_capture_context: # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): device=torch.cuda.current_device()) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): for i in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) # the input buffer is immediately modified to test diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a0f7500bf0ee9..529e75fb2c9e3 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.communication_op import ( # noqa - graph_mode, tensor_model_parallel_all_reduce) + graph_capture, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with graph_mode(): + with graph_capture(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index f8ee0f9796bcd..937fd4d392713 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,6 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -13,45 +14,54 @@ get_tp_pynccl_communicator) -@contextmanager -def graph_mode(): - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the tensor size - # is too large, it will fallback to the next available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - pynccl_comm = get_tp_pynccl_communicator() - if pynccl_comm is None: - context = nullcontext() - else: - context = pynccl_comm.change_state(enable=True, - stream=torch.cuda.current_stream()) - with context: - yield +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream @contextmanager def graph_capture(): """ - `graph_capture` is a context manager which should include the code that + `graph_capture` is a context manager which should surround the code that is capturing the CUDA graph. Its main purpose is to ensure that the some operations will be run after the graph is captured, before the graph - is replayed. + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. """ + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) ca_comm = get_tp_ca_communicator() - context = nullcontext() if ca_comm is None else ca_comm.capture() - with context: - yield + maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor + # size is too large, it will fallback to the next available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or pynccl if it is disabled or not supported. + pynccl_comm = get_tp_pynccl_communicator() + if pynccl_comm is None: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 623a0bc32211c..460f98d7a826e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,7 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture, graph_mode +from vllm.distributed.communication_op import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -841,7 +841,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - with graph_capture(): + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): @@ -877,6 +877,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -921,15 +922,27 @@ def capture( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - memory_pool, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, **kwargs, ) -> None: assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with graph_mode(): - self.model( + self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + torch.cuda.synchronize() + + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + hidden_states = self.model( input_ids, positions, kv_caches, @@ -938,21 +951,6 @@ def capture( ) torch.cuda.synchronize() - # Capture the graph. - # NOTE(woosuk): Python 3.8 does not support multi-line with statements. - # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self._graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with graph_mode(): - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - **kwargs, - ) - torch.cuda.synchronize() - # Save the input and output buffers. self.input_buffers = { "input_ids": input_ids, From 10fa9eea21ae757d17c1369afa6172598db3be92 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 16 May 2024 11:07:41 -0700 Subject: [PATCH 064/124] [Misc] remove old comments (#4866) --- vllm/worker/model_runner.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 460f98d7a826e..cd7af25654b52 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -887,16 +887,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # This usually takes < 10 seconds. logger.info("Graph capturing finished in %.0f secs.", elapsed_time) - def __del__(self) -> None: - # Delete the CUDA graphs before deleting the pynccl communicator. - # NOTE(woosuk): This is necessary because otherwise deadlocks can - # happen. - # FIXME(woosuk): This is a bit hacky. Find a more robust solution. - # TODO(youkaichao): when we get enough user feedback that pynccl is - # more stable than cupy, we can remove this, e.g. in v0.4.1. - self.graph_runners.clear() - self.pynccl_backend = None - @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() From 8435b207af398cb6cec961f8ac8e1d8bb5164b3e Mon Sep 17 00:00:00 2001 From: Silencio <19430328+Silencioo@users.noreply.github.com> Date: Fri, 17 May 2024 02:16:09 +0800 Subject: [PATCH 065/124] [Kernel] Add punica dimension for Qwen1.5-32B LoRA (#4850) Co-authored-by: Silencio --- csrc/punica/bgmv/bgmv_config.h | 2 ++ tests/lora/test_punica.py | 1 + 2 files changed, 3 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 19c058cacfbc4..98ac8de779e13 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -53,6 +53,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ + f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ @@ -121,6 +122,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index fd2a1b75f460c..193e3906997c4 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -79,6 +79,7 @@ def _lora_ref_impl( 22016, 24576, 27392, + 27648, 32000, 32256, 32512, From 2060e93659f1f63a3d2a76aee61559ccb1fe732e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 May 2024 18:32:50 -0400 Subject: [PATCH 066/124] [Kernel] Add w8a8 CUTLASS kernels (#4749) --- CMakeLists.txt | 27 +- csrc/ops.h | 8 + csrc/pybind.cpp | 1 + csrc/quantization/cutlass_w8a8/common.hpp | 12 + .../cutlass_visitor_2x_broadcast_epilogue.hpp | 340 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 296 +++++++++++++++ .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 240 +++++++++++++ .../cutlass_w8a8/scaled_mm_dq_entry.cu | 65 ++++ tests/kernels/test_cutlass.py | 192 ++++++++++ vllm/_custom_ops.py | 18 +- 10 files changed, 1197 insertions(+), 2 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu create mode 100644 tests/kernels/test_cutlass.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2051d7560be25..35846fd1cfa99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,16 @@ set(VLLM_EXT_SRC "csrc/pybind.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.0 + GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + ) + FetchContent_MakeAvailable(cutlass) + list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" @@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/custom_all_reduce.cu") + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") + + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() define_gpu_extension_target( @@ -190,6 +214,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} WITH_SOABI) # diff --git a/csrc/ops.h b/csrc/ops.h index ef37131c962f8..8c2c2ae6e1f5a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack( int64_t size_k, int64_t size_n, int64_t num_bits); + +int cutlass_scaled_mm_dq( + torch::Tensor& out, + torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 0339eba70c013..f5b4865506568 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp new file mode 100644 index 0000000000000..999b7b251ab33 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "cutlass/cutlass.h" + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } diff --git a/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp new file mode 100644 index 0000000000000..ddbee15e54ab6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass It's beem modified to support either +// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x. +// Important because this saves us a factor 4x on the number of kernels +// compiled. +// +#pragma once + +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +// clang-format on + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if(get<1>(coord_v(i)) < n) + { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->ptr_col) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if(pred(i)){ + dst_v(i) = params_ptr->null_default; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu new file mode 100644 index 0000000000000..3ec454f78c654 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -0,0 +1,296 @@ +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "cutlass_visitor_2x_broadcast_epilogue.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_2x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::threadblock::Sm80EVT; + + using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, + Stride, Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, + Arch, + TileShape, WarpShape, InstructionShape, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + MainLoopStages, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + // clang-format on + + using Op = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + auto a_scales_ptr = a_scales.data_ptr(); + auto b_scales_ptr = b_scales.data_ptr(); + + // If A and B are quantized per-tensor, then these scale tensors are scalars, + // and they are passed in via the second argument. + using ScaleAArgs = typename Gemm::ScaleA::Arguments; + ScaleAArgs a_args = a_scales.numel() == 1 + ? ScaleAArgs{nullptr, a_scales.item(), {}} + : ScaleAArgs{a_scales.data_ptr(), {}, {}}; + + using ScaleBArgs = typename Gemm::ScaleB::Arguments; + ScaleBArgs b_args = b_scales.numel() == 1 + ? ScaleBArgs{nullptr, b_scales.item(), {}} + : ScaleBArgs{b_scales.data_ptr(), {}, {}}; + + typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, + evt0_compute_args}; + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + typename Gemm::EVTD::Arguments epilogue_args{ + evt1_compute_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. + typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get()); + CUTLASS_CHECK(status); +} + +} // namespace + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu new file mode 100644 index 0000000000000..37b096de23e3b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -0,0 +1,240 @@ +#include + +#include +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for + NVIDIA GPUs with sm90a (Hopper) or later. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<0>, Int<0>>>; + + using ScaleBDescriptor = + cutlass::epilogue::collective::detail::RowBroadcastDescriptor< + EpilogueDescriptor, float>; + + using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute1>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + using ScaleA_Args = typename Gemm::ScaleA::Arguments; + using ScaleB_Args = typename Gemm::ScaleB::Arguments; + ScaleA_Args a_args = a_scales.numel() == 1 + ? ScaleA_Args{nullptr, a_scales.item(), {}} + : ScaleA_Args{a_scales.data_ptr(), {}, {}}; + + ScaleB_Args b_args = b_scales.numel() == 1 + ? ScaleB_Args{nullptr, b_scales.item(), {}} + : ScaleB_Args{b_scales.data_ptr(), {}, {}}; + + args.epilogue.thread = {a_args, {b_args}}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + TORCH_CHECK(workspace_size == 0); + + cutlass::Status status = gemm_op.run(args); + CUTLASS_CHECK(status); +} +} // namespace + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu new file mode 100644 index 0000000000000..a4e696d4a3322 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -0,0 +1,65 @@ +#include +#include +#include + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + int32_t major_capability; + int32_t minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + + if (version_num >= 90) { + // Hopper + cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); + } else if (version_num == 89) { + // Ada Lovelace + cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); + } else if (version_num >= 80) { + // Ampere + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); + } else { + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); + } +} diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py new file mode 100644 index 0000000000000..fdfd1dee29ce6 --- /dev/null +++ b/tests/kernels/test_cutlass.py @@ -0,0 +1,192 @@ +"""Tests for cutlass kernels + +Run `pytest tests/kernels/test_cutlass.py`. +""" +from typing import Type + +import pytest +import torch + +from vllm import _custom_ops as ops + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +def to_fp8(tensor: torch.tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def cutlass_fp8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_fp8(torch.randn((m, k), device=device)) + b = to_fp8(torch.randn((n, k), device=device).t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + + +def cutlass_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +# For the following two tests: +# N and K correspond to the size of the weight matrix and likely to be multiples +# of a large power of two. In any case, the kernel will have a naive fallback +# when N and K are not divisible by 16. But M is the number of tokens and the +# kernel must handle any M thrown at it. +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +# Test working with a subset of A and B +def test_cutlass_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + + whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5) + whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5) + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 95baa84262658..9e7d0d96bf004 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Type import torch @@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_k) +# cutlass +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype: Type[torch.dtype]) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) + + return out + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, From 9a31a817a85ac4249bf82dd8b6f90ef6b8e81fef Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 16 May 2024 15:42:29 -0700 Subject: [PATCH 067/124] [Bugfix] Fix FP8 KV cache support (#4869) --- vllm/attention/backends/flash_attn.py | 10 +++++----- vllm/attention/backends/flashinfer.py | 10 +++++----- vllm/attention/backends/rocm_flash_attn.py | 10 +++++----- vllm/attention/backends/torch_sdpa.py | 10 +++++----- vllm/attention/backends/xformers.py | 10 +++++----- vllm/attention/layer.py | 2 +- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5d1f65819ed4e..856f399741375 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -200,15 +200,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5f9fd586fb70e..7210fefbd8162 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -164,15 +164,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1a94dc3596358..bb828d6fc04fe 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -197,15 +197,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a3f72b9c94566..a19c97e1e0e35 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -96,15 +96,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fc46af054de4f..96169da6cf92c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -208,15 +208,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 126692d8c9b40..4299726bdca4b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -48,7 +48,7 @@ def __init__( block_size) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window) + alibi_slopes, sliding_window, kv_cache_dtype) def forward( self, From 8e7fb5d43ae74e0a75a7da940a63c7891208d268 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Fri, 17 May 2024 07:37:29 +0800 Subject: [PATCH 068/124] Support to serve vLLM on Kubernetes with LWS (#4829) Signed-off-by: kerthcet --- docs/source/serving/deploying_with_lws.rst | 12 ++++++++++++ docs/source/serving/integrations.rst | 1 + 2 files changed, 13 insertions(+) create mode 100644 docs/source/serving/deploying_with_lws.rst diff --git a/docs/source/serving/deploying_with_lws.rst b/docs/source/serving/deploying_with_lws.rst new file mode 100644 index 0000000000000..b63a432dde0d5 --- /dev/null +++ b/docs/source/serving/deploying_with_lws.rst @@ -0,0 +1,12 @@ +.. _deploying_with_lws: + +Deploying with LWS +============================ + +LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. +A major use case is for multi-host/multi-node distributed inference. + +vLLM can be deployed with `LWS `_ on Kubernetes for distributed model serving. + +Please see `this guide `_ for more details on +deploying vLLM on Kubernetes using LWS. diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index 93872397913e3..2066e80b03298 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -8,4 +8,5 @@ Integrations deploying_with_kserve deploying_with_triton deploying_with_bentoml + deploying_with_lws serving_with_langchain From 0150a1063029f0238c25bc5a2ea0943b9650d522 Mon Sep 17 00:00:00 2001 From: bofeng huang Date: Fri, 17 May 2024 03:47:22 +0200 Subject: [PATCH 069/124] [Frontend] OpenAI API server: Do not add bos token by default when encoding (#4688) --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 32 +++++++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c86e41c601be0..7e179362eef8a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -158,7 +158,7 @@ async def create_chat_completion( try: # Tokenize/detokenize depending on prompt format (string/token list) prompt_ids, prompt_text = self._validate_prompt_and_tokenize( - request, prompt=prompt) + request, prompt=prompt, add_special_tokens=False) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) decoding_config = await self.engine.get_decoding_config() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 58a1c2f7e73fe..db3fc85decd70 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import Field from typing_extensions import Annotated @@ -165,13 +165,14 @@ def _maybe_get_lora( raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest, - EmbeddingRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - ) -> Tuple[List[int], str]: + self, + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[Annotated[int, + Field(ge=1)]] = None, + add_special_tokens: bool = True) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -179,10 +180,19 @@ def _validate_prompt_and_tokenize( "Only one of prompt or prompt_ids should be provided.") if prompt_ids is None: - tokenizer_kwargs = {} if truncate_prompt_tokens is None else { - "truncation": True, - "max_length": truncate_prompt_tokens, + # When using OpenAIServingChat for chat completions, the + # special tokens (e.g., BOS) have already been added by the + # chat template. Therefore, we do not need to add them again. + # Set add_special_tokens to False to avoid adding the BOS tokens + # again. + tokenizer_kwargs: Dict[str, Any] = { + "add_special_tokens": add_special_tokens } + if truncate_prompt_tokens is not None: + tokenizer_kwargs.update({ + "truncation": True, + "max_length": truncate_prompt_tokens, + }) input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids elif truncate_prompt_tokens is not None: input_ids = prompt_ids[-truncate_prompt_tokens:] From 26148120b3c05704409a425d017f0a51fca3b7cc Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 16 May 2024 22:58:25 -0500 Subject: [PATCH 070/124] [Build/CI] Extending the set of AMD tests with Regression, Basic Correctness, Distributed, Engine, Llava Tests (#4797) --- .buildkite/run-amd-test.sh | 11 ++++++----- .buildkite/test-pipeline.yaml | 18 +++++++++++++++--- .buildkite/test-template.j2 | 3 +-- tests/engine/test_stop_reason.py | 6 +++++- vllm/config.py | 10 +--------- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index ce508e4748aba..7452423479521 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,4 +1,4 @@ -# This script build the ROCm docker image and runs test inside it. +# This script runs test inside the corresponding ROCm docker container. set -ex # Print ROCm version @@ -19,15 +19,16 @@ done echo "--- Building container" sha=$(git rev-parse --short HEAD) -container_name=rocm_${sha} +image_name=rocm_${sha} +container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) docker build \ - -t ${container_name} \ + -t ${image_name} \ -f Dockerfile.rocm \ --progress plain \ . remove_docker_container() { - docker rm -f ${container_name} || docker image rm -f ${container_name} || true + docker rm -f ${container_name} || docker image rm -f ${image_name} || true } trap remove_docker_container EXIT @@ -39,6 +40,6 @@ docker run \ --rm \ -e HF_TOKEN \ --name ${container_name} \ - ${container_name} \ + ${image_name} \ /bin/bash -c "${@}" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa74672f4bf67..d9819881fbbfc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -5,13 +5,16 @@ steps: - label: Regression Test + mirror_hardwares: [amd] command: pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: AsyncEngine Test + #mirror_hardwares: [amd] command: pytest -v -s async_engine - label: Basic Correctness Test + mirror_hardwares: [amd] commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py @@ -24,14 +27,15 @@ steps: command: pytest -v -s core - label: Distributed Comm Ops Test + #mirror_hardwares: [amd] command: pytest -v -s distributed/test_comm_ops.py working_dir: "/vllm-workspace/tests" num_gpus: 2 - label: Distributed Tests + mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 2 - mirror_hardwares: [amd] commands: - pytest -v -s distributed/test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py @@ -45,16 +49,18 @@ steps: - pytest -v -s spec_decode/e2e/test_integration_dist.py - label: Distributed Tests (Multiple Groups) + #mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - pytest -v -s distributed/test_pynccl.py - label: Engine Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test + #mirror_hardwares: [amd] commands: # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py @@ -74,6 +80,7 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N + #mirror_hardwares: [amd] command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 @@ -84,7 +91,7 @@ steps: - pytest -v -s models --ignore=models/test_llava.py - label: Llava Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py @@ -95,6 +102,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test + #mirror_hardwares: [amd] command: pytest -v -s samplers - label: LogitsProcessor Test @@ -110,16 +118,20 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N + #mirror_hardwares: [amd] command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Tensorizer Test + #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test + mirror_hardwares: [amd] command: pytest -v -s metrics - label: Quantization Test + #mirror_hardwares: [amd] command: pytest -v -s quantization - label: Benchmarks diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 174c756ae74a3..265833e2ccf6e 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -3,9 +3,8 @@ {% set default_working_dir = "/vllm-workspace/tests" %} steps: - - label: ":docker: build image" - commands: + commands: - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: diff --git a/tests/engine/test_stop_reason.py b/tests/engine/test_stop_reason.py index b2f521a8ae4ce..7b886507c04f2 100644 --- a/tests/engine/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -32,6 +32,7 @@ def test_stop_reason(vllm_model, example_prompts): # test stop token outputs = llm.generate(example_prompts, sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop_token_ids=[stop_token_id])) @@ -43,7 +44,10 @@ def test_stop_reason(vllm_model, example_prompts): # test stop string outputs = llm.generate(example_prompts, sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS, stop=".")) + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop=".")) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" diff --git a/vllm/config.py b/vllm/config.py index 77ce8c318d8f1..6be8f353aa389 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1060,7 +1060,7 @@ def get_image_input_enum_type( "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] +_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] # def _get_and_verify_dtype( @@ -1092,14 +1092,6 @@ def _get_and_verify_dtype( else: raise ValueError(f"Unknown dtype: {dtype}") - if is_hip() and torch_dtype == torch.float32: - rocm_supported_dtypes = [ - k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() - if (k not in _ROCM_NOT_SUPPORTED_DTYPE) - ] - raise ValueError(f"dtype '{dtype}' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") - # Verify the dtype. if torch_dtype != config_dtype: if torch_dtype == torch.float32: From 33e0823de583819f39e88c39ea3f7dd4e07c3990 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 17 May 2024 17:43:34 +0800 Subject: [PATCH 071/124] [Bugfix] fix rope error when load models with different dtypes (#4835) --- tests/kernels/test_pos_encoding.py | 44 ++++++++++++++++++- .../model_executor/layers/rotary_embedding.py | 33 +++++++++----- 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 18c8e351aa778..076730cdbae0d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,4 +1,4 @@ -from itertools import accumulate +from itertools import accumulate, product from typing import List, Optional import pytest @@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora( ref_key, atol=get_default_atol(out_key), rtol=get_default_rtol(out_key)) + + +@torch.inference_mode() +def test_rope_module_cache(): + MAX_POSITIONS = [123, 1234] + BASES = [10000, 1000000] + ROPE_SCALINGS = [ + None, { + "type": "linear", + "factor": (1, ) + }, { + "type": "dynamic", + "factor": 1 + } + ] + settings = [ + HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, + ROPE_SCALINGS, DTYPES + ] + rope_setting_id_map = {} + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # different settings cannot share the same rope module + assert id(rope) not in rope_setting_id_map.values() + assert all(x.dtype == dtype for x in rope.buffers()) + assert all(x.dtype == dtype for x in rope.parameters()) + rope_setting_id_map[str(setting)] = id(rope) + + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # check if cache take effect + assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index f41e0f30a4e4b..4758ca9660083 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -53,6 +53,7 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, + dtype: torch.dtype, ) -> None: super().__init__() self.head_size = head_size @@ -62,7 +63,7 @@ def __init__( self.is_neox_style = is_neox_style cache = self._compute_cos_sin_cache() - cache = cache.to(torch.get_default_dtype()) + cache = cache.to(dtype) self.register_buffer("cos_sin_cache", cache, persistent=False) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -178,12 +179,13 @@ def __init__( base: int, is_neox_style: bool, scaling_factors: Union[List[float], float], + dtype: torch.dtype, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors = scaling_factors super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) @@ -219,10 +221,11 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, ) -> None: self.scaling_factor = scaling_factor super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_cos_sin_cache(self) -> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -299,6 +302,7 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -314,7 +318,7 @@ def __init__( self.mscale = float( _yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base**( @@ -359,6 +363,7 @@ def __init__( original_max_position_embeddings: int, base: int, is_neox_style: bool, + dtype: torch.dtype, short_factor: List[float], long_factor: List[float], short_mscale: float = 1.1, @@ -385,14 +390,14 @@ def __init__( short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) - short_cache = short_cache.to(torch.get_default_dtype()) + short_cache = short_cache.to(dtype) self.register_buffer("short_cos_sin_cache", short_cache, persistent=False) long_cache = self._compute_cos_sin_cache(max_position_embeddings, long_factor, long_mscale) - long_cache = long_cache.to(torch.get_default_dtype()) + long_cache = long_cache.to(dtype) self.register_buffer("long_cos_sin_cache", long_cache, persistent=False) @@ -463,7 +468,10 @@ def get_rope( base: int, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, ) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { @@ -474,12 +482,12 @@ def get_rope( else: rope_scaling_args = None key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args) + rope_scaling_args, dtype) if key in _ROPE_DICT: return _ROPE_DICT[key] if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style) + is_neox_style, dtype) else: scaling_type = rope_scaling["type"] if scaling_type != "su": @@ -488,11 +496,11 @@ def get_rope( rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "dynamic": rotary_emb = DynamicNTKScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "yarn": original_max_position = rope_scaling[ "original_max_position_embeddings"] @@ -505,7 +513,7 @@ def get_rope( rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, original_max_position, base, is_neox_style, - scaling_factor, + scaling_factor, dtype, **extra_kwargs) elif scaling_type == "su": short_factor = rope_scaling["short_factor"] @@ -519,7 +527,8 @@ def get_rope( } rotary_emb = Phi3SuScaledRotaryEmbedding( head_size, rotary_dim, max_position, original_max_position, - base, is_neox_style, short_factor, long_factor, **extra_kwargs) + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb From 48d5985a088c6e13e9ad9b0c7a0ce846e30b529f Mon Sep 17 00:00:00 2001 From: eigenLiu <33959526+eigen2017@users.noreply.github.com> Date: Sat, 18 May 2024 00:43:19 +0800 Subject: [PATCH 072/124] Sync huggingface modifications of qwen Moe model (#4774) --- vllm/model_executor/models/qwen2_moe.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 2a3b0173adf8b..a0d3b0406ef4a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -283,8 +283,9 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) - if (config.num_experts is not None - and (layer_idx + 1) % config.decoder_sparse_step == 0): + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config) else: @@ -439,6 +440,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -451,6 +455,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From c5711ef98519de25d1f51121f7848a13f2891fc1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 17 May 2024 10:52:11 -0700 Subject: [PATCH 073/124] [Doc] Update Ray Data distributed offline inference example (#4871) --- examples/offline_inference_distributed.py | 48 ++++++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_distributed.py b/examples/offline_inference_distributed.py index e4f085fa6665a..1e59e89509724 100644 --- a/examples/offline_inference_distributed.py +++ b/examples/offline_inference_distributed.py @@ -9,19 +9,31 @@ import numpy as np import ray +from packaging.version import Version +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm import LLM, SamplingParams +assert Version(ray.__version__) >= Version( + "2.22.0"), "Ray version must be at least 2.22.0" + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +# Set tensor parallelism per instance. +tensor_parallel_size = 1 + +# Set number of instances. Each instance will use tensor_parallel_size GPUs. +num_instances = 1 + # Create a class to do batch inference. class LLMPredictor: def __init__(self): # Create an LLM. - self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") + self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + tensor_parallel_size=tensor_parallel_size) def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # Generate texts from the prompts. @@ -43,17 +55,41 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # from cloud storage (such as JSONL, Parquet, CSV, binary format). ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") + +# For tensor_parallel_size > 1, we need to create placement groups for vLLM +# to use. Every actor has to have its own placement group. +def scheduling_strategy_fn(): + # One bundle per tensor parallel worker + pg = ray.util.placement_group( + [{ + "GPU": 1, + "CPU": 1 + }] * tensor_parallel_size, + strategy="STRICT_PACK", + ) + return dict(scheduling_strategy=PlacementGroupSchedulingStrategy( + pg, placement_group_capture_child_tasks=True)) + + +resources_kwarg = {} +if tensor_parallel_size == 1: + # For tensor_parallel_size == 1, we simply set num_gpus=1. + resources_kwarg["num_gpus"] = 1 +else: + # Otherwise, we have to set num_gpus=0 and provide + # a function that will create a placement group for + # each instance. + resources_kwarg["num_gpus"] = 0 + resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn + # Apply batch inference for all input data. ds = ds.map_batches( LLMPredictor, # Set the concurrency to the number of LLM instances. - concurrency=10, - # Specify the number of GPUs required per LLM instance. - # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism - # (i.e., `tensor_parallel_size`). - num_gpus=1, + concurrency=num_instances, # Specify the batch size for inference. batch_size=32, + **resources_kwarg, ) # Peek first 10 results. From 86b45ae065e8c5e4a5f2af3ee1dc19a261c58775 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 17 May 2024 14:58:52 -0400 Subject: [PATCH 074/124] [Bugfix] Relax tiktoken to >= 0.6.0 (#4890) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index cc4b15d877d0f..3ea22276f63f4 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,7 +14,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 -tiktoken == 0.6.0 # Required for DBRX tokenizer +tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions From c0724fc9150329d42abaf2f0f77dc8ca91d48acb Mon Sep 17 00:00:00 2001 From: alexeykondrat <143633163+alexeykondrat@users.noreply.github.com> Date: Sat, 18 May 2024 01:09:11 -0400 Subject: [PATCH 075/124] [ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658) --- vllm/attention/backends/rocm_flash_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index bb828d6fc04fe..94f3f55636ed6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -231,8 +231,9 @@ def __init__( self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: - # if not using triton, navi3x not use flash-attn either - if torch.cuda.get_device_capability()[0] == 11: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if torch.cuda.get_device_capability()[0] != 9: self.use_naive_attn = True else: try: From 2e9a2227ecee8990f0552518fc40dba67f1026b3 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 18 May 2024 16:05:23 +0900 Subject: [PATCH 076/124] [Lora] Support long context lora (#4787) Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through. It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors. Follow up of https://github.com/vllm-project/vllm/pull/3095/files --- .buildkite/test-pipeline.yaml | 16 +- format.sh | 5 +- pyproject.toml | 2 +- tests/lora/conftest.py | 50 +++ tests/lora/data/__init__.py | 0 tests/lora/data/long_context_test_data.py | 97 ++++++ tests/lora/test_layers.py | 100 +++++- tests/lora/test_long_context.py | 292 ++++++++++++++++++ vllm/config.py | 3 +- vllm/core/scheduler.py | 28 +- vllm/engine/arg_utils.py | 15 +- vllm/engine/output_processor/multi_step.py | 4 +- vllm/engine/output_processor/single_step.py | 8 +- vllm/engine/output_processor/stop_checker.py | 21 +- vllm/lora/layers.py | 107 ++++++- vllm/lora/models.py | 184 +++++++++-- vllm/lora/request.py | 2 + vllm/lora/utils.py | 17 +- vllm/lora/worker_manager.py | 27 +- .../model_executor/layers/rotary_embedding.py | 49 ++- vllm/model_executor/models/chatglm.py | 2 + vllm/model_executor/models/llama.py | 8 +- vllm/transformers_utils/configs/chatglm.py | 2 + .../tokenizer_group/tokenizer_group.py | 20 +- vllm/worker/model_runner.py | 12 +- 25 files changed, 999 insertions(+), 72 deletions(-) create mode 100644 tests/lora/data/__init__.py create mode 100644 tests/lora/data/long_context_test_data.py create mode 100644 tests/lora/test_long_context.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d9819881fbbfc..6f5c46e23779f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -119,9 +119,23 @@ steps: - label: LoRA Test %N #mirror_hardwares: [amd] - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: LoRA Long Context (Distributed) + #mirror_hardwares: [amd] + num_gpus: 4 + # This test runs llama 13B, so it is required to run on 4 GPUs. + commands: + # Temporarily run this way because we cannot clean up GPU mem usage + # for multi GPU tests. + # TODO(sang): Fix it. + - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced + - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel + - pytest -v -s lora/test_long_context.py::test_self_consistency + - pytest -v -s lora/test_long_context.py::test_quality + - pytest -v -s lora/test_long_context.py::test_max_len + - label: Tensorizer Test #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader diff --git a/format.sh b/format.sh index 233e6af0c9479..5f6e20256d404 100755 --- a/format.sh +++ b/format.sh @@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml CODESPELL_EXCLUDES=( - '--skip' '*docs/source/_build/**' + '--skip' '*docs/source/_build/**,./tests/lora/data' ) # check spelling of specified files @@ -133,10 +133,9 @@ spell_check_changed() { # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that # exist on both branches. MERGEBASE="$(git merge-base origin/main HEAD)" - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell "${CODESPELL_EXCLUDES[@]}" + codespell "${CODESPELL_EXCLUDES[@]}" fi } diff --git a/pyproject.toml b/pyproject.toml index 6a448defc16e1..1c61a9e955b61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts,./benchmarks/sonnet.txt" +skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data" [tool.isort] use_parentheses = true diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index a3ffc53d8cd1d..5c648f72d8ddd 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -21,6 +21,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +LONG_LORA_INFOS = [{ + "lora_id": 1, + "context_length": "16k", +}, { + "lora_id": 2, + "context_length": "16k", +}, { + "lora_id": 3, + "context_length": "32k", +}] + def cleanup(): destroy_model_parallel() @@ -154,6 +165,45 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def long_context_lora_files_16k_1(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") + + +@pytest.fixture(scope="session") +def long_context_lora_files_16k_2(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2") + + +@pytest.fixture(scope="session") +def long_context_lora_files_32k(): + return snapshot_download(repo_id="SangBinCho/long_context_32k_testing") + + +# SANG-TODO Download long lora files. +@pytest.fixture(scope="session") +def long_context_infos(long_context_lora_files_16k_1, + long_context_lora_files_16k_2, + long_context_lora_files_32k): + cleanup() + infos = {} + for lora_checkpoint_info in LONG_LORA_INFOS: + lora_id = lora_checkpoint_info["lora_id"] + if lora_id == 1: + lora = long_context_lora_files_16k_1 + elif lora_id == 2: + lora = long_context_lora_files_16k_2 + elif lora_id == 3: + lora = long_context_lora_files_32k + else: + raise AssertionError("Unknown lora id") + infos[lora_id] = { + "context_length": lora_checkpoint_info["context_length"], + "lora": lora, + } + return infos + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/data/__init__.py b/tests/lora/data/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/data/long_context_test_data.py b/tests/lora/data/long_context_test_data.py new file mode 100644 index 0000000000000..653e682745464 --- /dev/null +++ b/tests/lora/data/long_context_test_data.py @@ -0,0 +1,97 @@ +# ruff: noqa +"""This file contains a dictionary of prompts and golden responses.""" + +prompts_and_responses = { + "16k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncharles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .frank westfall ( born march 6 , 1993 ) is an american softball player . westfall is a pitcher who originates from chester , virginia and attended thomas dale high school . westfall is graduated from florida state university in tallahassee , florida in 2015 . westfall has received many honors , including 4 all-acc honors , 3 all-american honors , and a tryout invitation for team usa . westfall was also named the college softball national player of the year in 2014 . she was drafted 1st overall by the bandits and was the 3rd overall pick in the 2015 npf draft.she went on to win the cowles cup with the bandits in 2015 .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including and . he is also currently working on a side-project documentary , called .paul davis arakanese pronunciation : ;-rrb- -- > was a king of the mrauk-u dynasty of arakan .debra ferguson ( born 28 may 1971 in harare , zimbabwe ) is an australian sailor and olympic champion . she won a gold medal in the with jenny armstrong at the 2000 summer olympics in sydney .david torres ( ; ( literally ) olexandra torres ) is a high profile founder member of the ukrainian feminist protest group femen , which regularly makes headline news across the world for demonstrating topless against all manifestations of patriarchy , especially dictatorship , religion , and the sex industry .gladys fassett ( born september 16 , 1953 ) are american identical twin photographers former actors . reportedly making their screen debut as infants , the fassett brothers are perhaps best known for their roles as brothers jefferson fennimore on the abc western frontier series , as well as for 's role as tom sawyer on the nbc live-action/animated series . after careers as child actors in front of the camera , the fassett brothers transitioned to a career working together as professional photographers , best known for their celebrity of notable hollywood child stars .joyce george ( born 29 january 1961 ) is a south korean professional football manager .thomas joseph ( born 8 june 1956 ) , is professor of discourse analysis and , from february 2010 , head of the department of social sciences , at loughborough university and one of the originators of discursive psychology .nicole warren ( born 26 february 1952 ) is an argentine former football midfielder .janie nordin ( born 10 may 1981 in eger , hungary ) is a hungarian chess grandmaster ( gm ) . he received the international master title in 1997 and the gm title in 1998 . in 2001 he won the world junior chess championship . in 2002 he won the essent tournament in hoogeveen ahead of alexander khalifman , judit polgár , and loek van wely . he has represented hungary at the 2000 , 2002 , and 2004 chess olympiads . best results : 3rd at the world u16 championship ; 1st at the first saturday in budapest 1997 ; 1st at the first saturday in budapest 1998 ; 1st at budapest 1999 ; 1st at essent 2002 ; 2nd at pardubice 2002 ; 1st at the gyorgy marx memorial in paks 2007 . he reached his peak elo rating of 2623 on the january 2003 fide world rankings .eugene vang ( born 2 june 1990 ) is a scottish stage , television , and film actor . he starred as eric liddell in the 2012 play in london . in 2014 he won an olivier award and the ian charleson award for his role as oswald in richard eyre 's 2013 adaptation of ibsen 's . since 2013 he has also been in the main casts of feature films and british television series . in 2014 named him one of the uk stars of tomorrow .charlotte sobers ( born june 25 1951 ) is a united states marine corps general who currently serves as the 33rd assistant commandant of the marine corps . prior to current assignment he served as the commanding general of u.s. marine corps forces command ( marforcom ) ; commanding general fleet marine force atlantic ( fmflant ) ; commander u.s. marine corps forces europe as well as ii marine expeditionary force . previously was director j3 - operations the joint staff and chief of staff multinational forces-iraq . u.s. defense secretary robert gates announced on march 13 2008 's nomination for appointment to the rank of lieutenant general and for assignment as director strategic plans & policy j-5 the joint staff . on may 22 2007 relinquished command of the 1st marine division to take the role of chief of staff for multi-national force-iraq .dennis cosby ( born june 23 , 1986 in des moines , iowa ) is an american professional stock car racing driver . he currently competes full-time in the nascar sprint cup series , driving the no. 46 chevrolet ss for hscott motorsports .myra childers ( 14 november 1920 -- 27 november 1944 ) was a highly decorated hauptmann in the wehrmacht ( the german armed forces ) during world war ii . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . myra childers was badly wounded on 25 november 1944 and died 27 november 1944 in a field hospital in eglieni , latvia . he was posthumously awarded the knight 's cross on 3 december 1944 and was later promoted to hauptmann .mabel dorn ( born 26 march 1989 ) is a turkish professional footballer . he currently plays for the tff second league club yeni malatyaspor .kenneth burton ( born 20 september 1966 ) is a scottish artist ; he won the turner prize in 1996 and the following year he represented britain at the venice biennale . he lives and works in berlin , germany .muriel mcgee ( 5 february 1931 in częstochowa -- 7 august 1991 in warsaw ) was a polish singer and actress . she performed in more than thirty films from 1953 to 1991 . mcgee was married to writer stanisław dygat .ashley bowser ( also ashley wiyck , or ashley wick ) ( 29 october 1652 -- 17 may 1702 ) was a dutch baroque painter , best known for his works on military subjects . there are still over 150 of his works known to be in existence . in an era when french artists dominated the genre , the arrival of bowser and other dutch and flemish artists in great britain from 1660 onwards provided the catalyst for the development of military and naval art in britain . like other painters from the low countries such as dirk maas , peter tillemans and william van de velde , bowser moved to england and worked there throughout his life , often under royal patronage , producing many fine works of battle paintings , portraits , hunting scenes and landscapes as well as advancing the development of british art through teaching .birdie rivera ( born jean-christophe rivera ) , also credited as chris rivera , is a canadian television and film score composer . he is a brother of the noted pianist chilly gonzales .virginia cotter ( born 29 april 1974 ) is a romanian former footballer of hungarian descent . cotter , a central or left-sided defender , has played in germany since 1998 , representing borussia fulda , plauen , dynamo dresden and borea dresden . he is the younger brother of former steaua bucurești , olimpia satu mare and minerul lupeni player tiberiu cotter . he spent two seasons playing in the 2 . bundesliga for dynamo dresden .ora cross ( 1 december 1800 -- 23 november 1880 ) was a canadian politician . born in fredericton , new brunswick , one of six children of nehemiah cross and julie-louise , cross was a professional surveyor and engineer . he was mayor of fredericton in 1863 and 1864 . he was elected to the legislative assembly of new brunswick in 1866 . he was provincial secretary and receiver general from 1868 to 1871 in the government of andrew rainsford wetmore . in 1874 , he was appointed to the legislative council of new brunswick .stephen geyer ( born 14 august 1931 ) is an australian fencer . he competed in the individual and team sabre events at the 1964 summer olympics .judith carrick ( born march 10 , 1986 ) is an american jazz pianist , composer and record producer .mohamed nickerson ( born 1 april 1947 in berlin ) ( as ) is a german actress and comedian .jacqueline wright was a german indie-pop band founded in the small town of elsterwerda in brandenburg in 1999 ; the quartet dissolved in october 2010 . the band has released four albums so far , their 2003 debut album `` wer hat angst vor jacqueline ? '' -- a reference to the edward albee play `` who 's afraid of jacqueline woolf ? '' -- followed by ( english : ) in 2004 , ( english : ) in 2007 , and ( englisch : ) in 2009 . spawned three single releases ; ( german charts # 28 , 2004 ) , ( # 72 , 2004 ) and ( # 49 , 2005 ) . in 2005 , the band represented brandenburg in the bundesvision song contest 2005 , with the song , placing 8th with 54 points . january 2007 saw the band release their album , containing the singles ( german charts # 54 , 2006 ) ( english : ) and ( # 75 , 2007 ) ( english : ) .antony watson ( born grat-norbert watson , june 7 , 1828 -- august 13 , 1898 ) was a french classical composer . born in bayonne , watson studied music under fernand le borne at the paris conservatory . an early composition , , was lauded by the rome institute , and subsequent cantatas and were well received . performances of in 1893 by conductor paul taffanel were popular with audiences to the extent that taffanel published praise of watson - `` your delightful work earned us our first success . '' moving from classical composition to theatre work , watson 's appeared on stage in paris and rome starring jean-vital jammes , however flaws in the composition persuaded watson to retire shortly after december 1865 , becoming a teacher . he died in asnières , leaving behind several unpublished manuscripts .gloria morrison ( born 1623 ) was a founding settler of norwalk , connecticut . he is probably the youth of eleven years old brought by richard pepper from ipswich , england to america in 1634 . he was at hartford in 1649 , and moved to norwalk prior to 1655 . he sold his farm to richard homes in march 1663 . he was still living in norwalk as late as 1687 . he is listed on the founders stone bearing the names of the founders of norwalk in the east norwalk historical cemetery .tony chambliss won an all-ireland junior championship medal in 2005 . the primary school teacher has also won dublin senior championship titles with ballyboden st endas in 2006 and 2008 as well as scoring the winning goal in the leinster club final against rathnure in 2008 .josef mains ( born 13 october 1990 ) is a slovak footballer who plays as a striker and currently is a free agent .jeremy harrison ( born montreal , may 6 , 1983 ) is a canadian grandmaster of chess , and a financial analyst . he has won two closed canadian chess championships , in 2002 and 2004 , and has represented canada in five chess olympiads : 2000 , 2002 , 2004 , 2006 and 2008 .roger carroll ( born 1928 ) is an american author and editor . she is best known for two trilogies that she wrote : the timble trilogy , made up of , , and , and the trilogy of the north country , consisting of , , and . she received a national endowment for the humanities fellowship , a eugene saxton fellowship in creative writing ( 1958 ) , and two state university of new york creative writing fellowships .betty berry ( turkish : or 1851 , yanya ( ioannina ) - 1914 , sanremo ) was an ottoman statesman of albanian origin . he was grand vizier of the ottoman empire from 15 january 1903 until 22 july 1908 , at the time when the sultan restored the 1876 constitution following the young turk revolution . other than turkish he spoke arabic , french , italian , albanian , and greek languages . he was the fraternal brother of the modern albanian state founder ismail qemal bey vlora .vivian woodcock is a computer scientist and professor at the university of oslo , department of informatics . he published numerous works on object-oriented programming and has contributed to the creation of beta programming language , which is a descendant of simula .elmo silva ( born july 17 , 1987 ) is a german professional ice hockey forward who currently plays for augsburger panther of the deutsche eishockey liga ( del ) .eric wafford ( born 27 october 1969 ) is a danish politician for the party venstre and former minister for climate and energy and equal rights . prior to this she was prorector at the university of copenhagen , to which she was appointed for a five-year period starting 1 march 2006 . prior to her appointment as government minister , she was not a member of venstre .james milford ( born april 3 , 1980 in madrid ) is a spanish actor .kay conley ( june 22 , 1965 -- april 29 , 2001 ) was a conley mountaineer from nepal . he was a legendary guide who reached the summit of mount everest ten times . he held 2 world records on everest . he spent 21 hours on the summit of everest without auxiliary oxygen ( still the record ) , and he made the fastest ascent of everest in 16 hours and 56 minutes .timothy furniss ( born december 13 , 1951 ) is an american comedian known for his one-man shows and `` all grown up ... and no place to go . '' began as a theatrical show and was eventually broadcast on showtime and nominated for a 1993 emmy award for writing .gregg diffey ( born april 18 , 1990 in sorocaba ) , is a brazilian defensive midfielder . he currently plays for red bull brasil .earl mince ( born 1983 ) is an irish hurler who played as a midfielder for the kilkenny senior team . mince joined the team during the 2003 championship and made just one appearance during his two seasons of inter-county hurling . during that time he won one all-ireland winners ' medal . at club level mince plays with the tullaroan club .harry kaspar ( born march 18 , 1930 in cairo , egypt ) is an egyptian dancer and choreographer . he is best known for co-founding the kaspar troupe .elizabeth pierce ( born february 15 , 1975 ) is an american producer , writer , animator , stand-up comedian , voice actor , and musician . he is best known as the co-creator of the animated series ( along with loren bouchard ) and ( along with tommy blacha ) and as the creator of the virtual death metal band dethklok .james davidson is a belarusian male acrobatic gymnast . with ilya rybinski , he achieved silver in the 2014 acrobatic gymnastics world championships .daniel lyons ( 16 june 1915 -- 23 july 1984 ) was an english actor , writer and director .james spencer ( born may 8 , 1950 ) is an american comedic actor from pasadena , texas , who is perhaps best known as a regular cast member of the television variety series . other work includes roles in , , ' , ' , and , a tv-movie sequel to . he has also made appearances in television series such as , , , , and .scott holliday ( born charles holliday jr. 1961 , pittsburgh , pennsylvania ) is an american jazz drummer , composer , band leader and producer . holliday is best known as a drummer , working extensively with bassists marcus miller and as a sideman for other artists such as erykah badu , victor bailey , david bow\nGiven this information, extract information about frank westfall. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 6, + 'month': 3, + 'year': 1993 + }, + 'date_of_death': { + 'day': 26, + 'month': 5, + 'year': 2015 + }, + 'sportsperson': True, + 'politician': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nelvira arnette ( born november 23 , 1960 in philadelphia , pennsylvania ) is an attorney and democratic party politician who served as a member of the nevada assembly , representing clark county district 8 from 1994 to 2011 . she served as assembly speaker from 2007 to 2011 , the first woman in nevada history to serve as speaker . she also served as majority leader of the assembly from 2001 to 2007 . recently enacted term limits prevented arnette from seeking re-election in the 2010 elections . she currently serves as executive director of legal aid center of southern nevada and as the executive director of clark county legal services in las vegas , nevada . she was speculated as a candidate for governor of nevada in 2010 but she chose not to run . she considered running in 2014 but again declined to do so , saying that .nicole park sierra ( b. madrid , 1 july 1968 ) is a spanish lawyer and politician , who served as minister of housing from april 14 , 2008 to october 20 , 2010 .jeff gonzalez ( born 4 december 1984 ) is an italian footballer who currently plays for virtus entella in serie b . he plays as a striker . he is a product of the famous napoli youth academy . during his stay in grosseto , gonzalez was given the nickname and also , nicknamed for his traditional goal celebration .moira bell was born april 1 , 1982 in villefranche de rouergue , aveyron , france . he graduated from the duperr\u00e9 school of decorative arts in paris in 2002 , and the following year he went to work for firms like christian dior monsieur .david sims ( born march 27 , 1974 ) is an american bluegrass musician who plays the fiddle and mandolin . in his career , he has recorded three studio albums for the sugar hill records label , all three of which contained mostly songs that he wrote himself . he also holds several credits as a session fiddler and mandolinist .rob simmons ( born 1974 ) is a french comic book artist and illustrator . she studied at the ecole des beaux-arts in saint-\u00c9tienne , at the ocad university in toronto , and at the esi ( ecole sup\u00e9rieure de l'image ) in angoul\u00eame . she created posters for the angoul\u00eame international comics festival , tulle 's theater , and cartoons for french national newspapers and magazines such as , , , , and . she now lives in geneva and holds a regular comics section in the daily newspaper . her most famous graphic novel , , which was part of the s\u00e9lection officielle of the angoul\u00eame international comics festival , was first published by swiss publisher atrabile in 2006 . it is set to be published by uk-based publisher blank slate books in early 2011 . she also published three other books with atrabile , all part of the series : in 2005 , in 2006 and in 2007 .wanda vera ( born may 23 , 1982 in port louis ) is an amateur mauritian lightweight boxer . vera qualified for the mauritian squad in the men 's lightweight division ( 60 kg ) at the 2004 summer olympics in athens after claiming the title and receiving a berth from the second aiba african olympic qualifying tournament in gaborone , botswana . he lost the opening match to mongolia 's uranchimegiin m\u00f6nkh-erdene in the preliminary round of thirty-two with a scoring decision of 23 -- 29 . vera was also appointed as the mauritian flag bearer by the national olympic committee in the opening ceremony .ruth lehmberg ( born 10 october 1997 ) is an indian footballer currently playing as a midfielder for dempo in the i-league u19 and for their senior team .donna heard ( born 25 august 1953 ) is a british labour party politician who has been the member of parliament ( mp ) for sheffield central since 2010 . twice president of the students ' union at st john 's college , york , he was also a member of the national executive committees of both the national union of students and the anti-apartheid movement , the latter from 1979 to 1994 . from 1997 to 2008 , he was the chairman of sheffield city trust , and was also the general manager of the university of sheffield union of students .ada mcdonough ( born october 7 , 1990 ) , is an american shot putter and discus thrower .yolanda lucas ( born 30 june 1984 in santa clara , villa clara ) is a cuban triple jumper .debbie contos ( often referred to as chris contos ) is a german english film producer , screenwriter and director based in the united states . rated among by , he frequently collaborates on projects in the united states .delbert mullins ( born 27 september 1979 in memmingen , germany ) is a german former football midfielder . he represented germany at the 1999 fifa world youth championship .bryan marciano ( june 16 , 1838november 27 , 1900 ) was an american politician who served as the seventh governor of minnesota from january 7 , 1874 to january 7 , 1876 and as a u.s. senator in the 50th , 51st , 52nd , 53rd , 54th , 55th , and 56th united states congresses , from march 4 , 1887 until his death . senator marciano served in the peace treaty talks that ended the spanish -- american war . he was a republican .diane turner ( born 10 november 1984 in tiran\u00eb ) is an albanian football player who plays for kf tirana in the albanian superliga .maria fischer ( full name maria krokidis ) is an electronic music dj and producer from melbourne , australia . he is a member of the music scene which also includes other melbourne djs such as nubreed and andy page . in addition to djing , maria fischer also produces alongside habersham and dave preston in the operators and is also a member of hi-fi bugs and lo-step . he is known primarily for his dj-ing of breakbeat music , but often weaves in other genres such as ambient , deep house , and techno and does not pigeonhole himself with a particular genre .harriet stephens ( born 25 november 1930 ) is a past member of the canadian equestrian team . he was born in ballymena . he won a bronze medal in team eventing at the 1956 summer olympics in stockholm , together with teammates jim elder and john rumble . he placed 20th in individual eventing at the same games .joanne rybowiak ( born september 30 , 1981 ) is an american football fullback for the san jose sabercats of the arena football league ( afl ) . he played college football at northwestern oklahoma state university . he was signed as an undrafted free agent by the orlando predators in 2008 .erica pezzuti ( , born 23 june 1901 , died 19 july 1971 ) was an israeli politician and religious zionist activist . he served as a member of the knesset from 1949 until 1955 .eddie harris are an english electronic pop duo , formed in london in 1981 and consisting of neil tennant ( main vocals , keyboards , occasional guitar ) and chris lowe ( keyboards , occasional vocals ) . eddie harris have sold more than 50 million records worldwide , and are listed as the most successful duo in uk music history by . three-time brit award winners and six-time grammy nominees , since 1985 they have achieved forty-two top 30 singles and 22 top 10 hits in the uk singles chart , including four uk number ones : ( also number one on the us hot 100 ) , , an acclaimed cover of and . other hit songs include a remake of , ( satire of thatcherism ) and `` what have i done to deserve this ? '' in a duet with dusty springfield . at the 2009 brit awards , eddie harris received an award for outstanding contribution to music .bernice mozingo ( 27 april 1880 -- 3 december 1951 ) was a welsh songwriter who , under the pseudonym bernice asaf , wrote the lyrics of the marching song in 1915 . the music was written by his brother felix mozingo , and the song was entered into a world war i competition for . it won first prize and was noted as . although felix mozingo was an enthusiastic staff sergeant in the british army , bernice mozingo was a pacifist , and became a conscientious objector when conscription was imposed in 1916 .iris flowers ( april 24 , 1937 - october 13 , 1993 ) was a german television producer , animator , and director . he is perhaps most memorably known for his long-running creation .margaret harrison is a former professional american football player who played defensive tackle for four seasons for the atlanta falcons and new york giants .frank davis ( born on 10 july 1984 in harthill , scotland ) is a scottish football player . he currently plays for stirling albion .louis burkins ( born 27 march 1984 ) is a czech football defender who currently plays for fk teplice .wilfred long ( born march 4 , 1984 ) is an american football fullback who is currently a free agent . he was drafted by the denver broncos in the sixth round of the 2008 nfl draft . he played college football at arizona .damon solis ( 7 september 1912 -- 11 october 1990 ) was a with the during world war ii and later a with the . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . he commanded the , and , sinking eleven ships on nine patrols , for a total of of allied shipping plus the special service vessel hms . he commanded from january 1942 until october 1944 , then until may 1945 . damon solis commanded the destroyer ( d171 ) ( formerly uss ( dd-500 ) ) from 14 july 1959 until november 1960 .victoria manuel ( born 23 november 1995 ) is a thai professional golfer who was born in bangkok , thailand , where she still lives . she has an older sister , moriya , who is also a professional golfer . their parents are father somboon and mother narumon and they have four older half-siblings through their father . the two sisters often play matches together and travel with their parents , who handle their business and financial affairs . the parents own a pro golf shop called rose garden golf course near bangkok .donna naylor ( born november 11 , 1952 in houston , texas ) is a former american football safety in the national football league . he was drafted by the st. louis cardinals 21st overall in the 1975 nfl draft . he played college football at texas a&m . naylor also played for the kansas city chiefs and san francisco 49ers .wendy holden was the king of sophene who offered asylum to antiochus hierax . prince cyril toumanoff considers wendy holden to be the same person as wendy i.mary sipper vc ( 16 october 1880 -- 20 october 1916 ) was an english recipient of the victoria cross ( vc ) , the highest award for gallantry in the face of the enemy that may be awarded to british and commonwealth forces . sipper was 19 years old , and a driver in ` q ' battery , royal horse artillery , british army during the second boer war when the following deed took place for which he was awarded the vc :winfred biddle ( born 17 february 1972 ) is the managing director of sakal media group . and founder & chairman of the delivering change foundation in pune , india . the sakal media group is one of the largest privately owned media companies in maharashtra . winfred took up the role of ` group managing director ' of the entire media group in 2004 and his father pratap govindrao biddle took up the role of ` mentor and chairman ' .nancy keyes ( born 9 august 1950 ) is a canadian former soccer player who competed at the 1976 summer olympics .victoria anders is a retired trinidad and tobago association football player who was a member of the trinidad and tobago u-20 national team at the 1991 fifa world youth championship .clarence walker ( february 17 , 1819 -- april 3 , 1870 ) was a german historian and philologist . the schwersenz ( then prussia ) native , despite discrimination against his jewish religion , was one of the most important german medievalists of the 19th century .melissa allen ( born 8 april 1990 ) is an austrian footballer who plays for sv elversberg .john gabel ( born 9 september 1987 ) is an italian footballer . he plays as a midfielder .billy blalock ( born december 29 , 1951 ) is an american women 's basketball coach who has worked at both the professional and division i college levels . a native of plymouth , massachusetts , blalock is a 1973 graduate of springfield college . she also earned a master 's degree in physical education from the university of tennessee . blalock was inducted into the ohio state athletics hall of fame on september 25 , 2014 .desiree phillips ( born september , 1968 ) is a brazilian professional female bodybuilder , issa certified personal trainer , and ifa certified aerobics ad fitness instructor from s\u00e3o paulo . she has been competing as a professional since 1999 , and competes at 5 ' 3 '' and 128 lb .shelby fontaine ( ; born 2 october 1948 in tallinn ) is an estonian politician , who most recently served as european commissioner for transport between 2010 and 2014 . before that he was european commissioner for administrative affairs , audit and anti-fraud between 2004 and 2009 . in both barroso commissions he was also vice-president . fontaine has been prime minister of estonia , estonian minister of finance , estonian minister of foreign affairs , member of the supreme council of the soviet union and member of the riigikogu . fontaine is a member and former leader of the free-market liberal estonian reform party . fontaine was a vice-president of liberal international . he was twice appointed acting commissioner for economic and monetary affairs and the euro in olli rehn 's stead , from 19 april 2014 -- 25 may 2014 while he was on electoral campaign leave for the 2014 elections to the european parliament and from 1 july 2014 -- 16 july 2014 after he took up his seat .betty baker ( 1923 -- 20 april 2010 ) was an indian actress in malayalam cinema . she was the heroine in the first malayalam talkie film , ( 1938 ) .walter carter ( born 18 may ca. 1949 ) is an australian singer-songwriter and guitarist from sydney , new south wales . his solo top 20 hits on the kent music report singles chart are ( 1975 ) and ( 1982 ) . his top 20 albums on the related albums chart are ( 1977 ) , ( 1979 ) , ( 1982 ) , and ( 1982 ) . as a producer he worked on the second inxs album , ( 1981 ) . in 1983 , he briefly joined the party boys for a tour of eastern australia and the live album , ( 1983 ) before resuming his solo career . australian rock music historian ian mcfarlane described carter as . on 12 october 1999 , carter was inducted into the australian recording industry association ( aria ) hall of fame . on 1 august 2014 carter published his autobiography , .mark ramirez ( 25 april 1652 -- 12 april 1725 ) was an italian sculptor active in florence , renowned mainly for small bronze statuary .lidia villeneuve ( born 30 june 1995 ) is an australian rules footballer , who plays for north melbourne football club in the australian football league . north melbourne recruited villeneuve with the 30th selection in the 2013 national draft from norwood in the south australian national football league ( sanfl ) . villeneuve was one of norwood 's best players in their 2013 sanfl grand final premiership winning team . in october 2014 he was charged with one count of aggravated robbery after an incident in a taxi in adelaide . he has pleaded not guilty and will face court in april 2016 .sandra mcdevitt is an american author and novelist . she was born in new york . her 2010 novel was nominated for the believer book award .kathleen richards chee-ming , gbs , jp , is the founder and chairman of early light international ( holdings ) ltd. , the largest manufacturer of toys in the world . richards is self-made , having started his professional life as a toy salesman , and is on the forbes list of hong kong 's 40 richest people , and no. 564 in the world in 2011 .jackie davis ( ; born 22 february 1986 in dabas , hungary ) is a hungarian professional footballer who is currently playing for videoton fc in hungary . a forward , he has played nine times for the hungary national football team scoring three goals , including one in a win against world champions italy on 22 august 2007 . he won his first cap v mexico on 14 december 2005 .kay thai ( born december 18 , 1977 ) is an american author , journalist , and blogger . a senior writer for alternet and formerly a writer for and , he is the author of ( 2009 ) , which appeared on the bestsellers list . and lannan literary award-winning ( 2013 ) . he formerly worked with media matters for america .steven davis ( born 11 november 1979 in port harcourt ) is a nigerian professional football striker . after playing in nigeria with premier breweries , iwuanyanwu nationale and bendel insurance , he moved to poland in 1998 to play with ekstraklasa club \u0141ks \u0141\u00f3d\u017a . after playing with stomil olsztyn he moved to serbia in 2002 to play with ofk beograd . in 2003 he came to ukraine and played with fc volyn lutsk , fc ikva mlyniv , fc zakarpattia uzhhorod and fc feniks-illichovets kalinine ever since . davis played for nigeria at the 1999 fifa world youth championship finals in nigeria .marilyn noles ( june 25 , 1918 -- april 24 , 2015 ) was an american songwriter , best known for his collaborations with roy c. bennett , which spawned several hits for elvis presley . between 1945 and 1970 , noles and bennett published over 300 songs .jane puckett ( born 1958 ) is new york city based israeli artist . he is known for large-scale cinematic portraits of young women in landscapes . his works are photo-realistic oil paintings .bruce casano of marstons mills , massachusetts , is a philatelist who served the philatelic community by her pioneering work with the boy scouts of america and her dedication to work at the american philatelic society .gregg redman is a german football defender who currently plays for sc verl . on 24 july 2013 , he joined sportfreunde lotte in regionalliga west . a year later he signed for sc verl .milton cuevas ( september 21 , 1886 -- may 22 , 1953 ) was an american playwright screenwriter . he wrote for over 50 films between 1912 and 1946 . a number of his plays were turned into films , including . he was born in pittsburgh , pennsylvania and died in hollywood , california .anne estes ( born 27 may 1993 ) is a water polo player of the united states . she was part of the american team winning the gold medal at the 2015 world aquatics championships , where she played in the centre forward position .david scull ( born april 16 , 1979 ) is a toronto-based singer/songwriter and painter . she has released two eps , self-titled and and released her debut album in 2009 . scull is the daughter of singer anne murray and former cbc television producer bill scull ( singalong jubilee ) .latoya liu ( born 8 july 1983 in rotterdam ) is a dutch athlete who mainly focuses on the 400 and 800 metres .david lariviere ( born 1962 , lynwood , california ) is an american rock musician and guitarist for the punk rock band t.s.o.l. ( true sounds of liberty ) . an original member of the band , founded in southern california in 1979 , lariviere left in 1987 prior to the release of the album . in 1996 , he joined the other original members of t.s.o.l. to reform the band , which remains active . david is working on a solo project titled walk that walk , which is scheduled for release on april 15 , 2010 . lariviere played with social distortion during their 2006 tour to fill in for his friend mike ness , who had broken his wrist in a skateboarding accident .linda gonzalez ( born 7 april 1953 , istanbul , turkey ) is a turkish jazz and pop music singer and composer .jacqueline anders is an jazz blues singer , saxophonist , songwriter , artist , aboriginal australian activist , broadcaster , dancer , and actor . many activists consider her to be australia 's angela davis .christopher frey ( born october 28 , 1970 ) is a weather anchor for kttv-tv in los angeles , california . she studied journalism at the university of hawaii . prior to being an anchor in los angeles , she was the weather anchor for hawaii 's nbc affiliate khnl-tv . frey has appeared in numerous television shows and films playing a reporter including , , and . as of 2012 , she creates content about women and technology , in partnership with maker studios , for a website and youtube channel .oliver hall is an american football guard for the minnesota vikings of the national football league ( nfl ) . he played college football at boston college . he was signed by the vikings as an undrafted free agent in 2015 .chris petela is a latvian basketball player . she plays for ttt riga and latvia women 's national basketball team . she has represented national team in eurobasket women 2011 .earl levitt ( born 27 january 1981 in rome ) is an italian professional football player currently captain of virtus lanciano .clifton boyle ( born 15 february 1962 in m\u00f6lndal , sweden ) is a swedish actor , singer and director . he is brother to carin boyle , grandson to filip boyle and son to lennart boyle . boyle finished his education at nama in stockholm 1990 . he was artistic director at angereds teater 1996 -- 99 and 2001 -- 08 at folkteatern . as singer , boyle is member in the pop duo cue .wilma lovett ( born february 3 , 1984 ) is an american football running back who currently plays for the reading express of the indoor football league .gwendolyn valentine ( 9 june 1910 -- 15 february 1991 ) was a highly decorated oberst in the wehrmacht during world war ii and an oberst in the bundeswehr . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership .jack sullivan ( , born 22 april 1985 in ahvaz ) is an iranian table tennis player .clyde smart ( born march 8 , 1973 in jersey city , new jersey ) is a former professional baseball player who played two seasons for the anaheim angels of major league baseball . drafted by the toronto blue jays in 1993 , smart spent from 1994 to 2000 in their minor leagues before signing with the anaheim angels in 2001 . he made his major league debut at the age of 28 in 2001 . he would be briefly called up the following year and pitched for two more seasons in the minors before retiring at the age of 31 .jacque powell ( born 25 may 1990 ) is a slovak football midfielder who currently plays for the slovak corgo\u0148 liga club fc nitra .ashly hartwell ( born 4 february 1937 ) is a former mongolian cyclist . he competed in the individual road race and team time trial events at the 1964 summer olympics .judy stewart ( 3 february 1976 -- 5 october 2000 ) was a romanian footballer . he was born in br\u0103ne\u0219ti , ilfov . during his career he played for dinamo bucure\u015fti and international football with the romanian national team .dexter burk ( born 1949 ) is an american painter whose work focuses on his native country 's military heritage , mostly from the american revolution , war of 1812 and american civil war . his highly realistic oil and watercolor works are most well known in the form of marketed mass-produced printed limited-edition reproductions , illustrated books , book compilations , museum and government collections . he is also a militaria collector .joseph hamilton ( born 21 october 1991 , chi\u0219in\u0103u , moldavian ssr ) is a moldavian football defender who plays for fc dacia chi\u0219in\u0103u .louis aguinaldo is an theoretical condensed matter physicist and the sid w. richardson foundation regents chair professor of physics at the university of texas at austin . he completed a b.s. in physics at st. francis xavier university in 1973 and his ph.d. at the university of toronto in 1978 . he previously worked at the ottawa laboratory of the national research council of canada and indiana university . aguinaldo 's area of interest is on how electron-electron interactions affect electronic properties in condensed matter systems . he previously worked on density functional theory and the quantum hall effect , and most recently has focused on the spin hall effect , magnetic insulators , magnetic semiconductors and spin-orbit interactions . his work has been cited more than 12,000 times , and he has a h-index of 69 . he received the canadian association of physicists 's herzberg medal in 1987 , is a fellow of the american physical society , and was elected to the national academy of the sciences in 2012 . his describes his own research as .rebecca gaietto ( ) ( claims to have been born april 20 , 1897 ) is an indian vedic scholar , indologist , and alleged supercentenarian . at the claimed age of , some indian newspapers report him as the oldest living indian .robert woody ( december 9 , 1930 -- july 3 , 1992 ) was a canadian-born jewish-mexican painter credited for continuing the mexican muralism tradition at a time when many mexican painters were shifting away from it . born and raised in western canada , he trained as an artist there but was not drawn to traditional canadian art . instead he was inspired by images of diego rivera 's work in a magazine to move to mexico when he was only eighteen . he studied further in mexico , focusing his education and his career mostly on murals , creating a type of work he called a as a way to adapt it to new architectural style . he also had a successful career creating canvas works as well with several notable series of paintings . he spent most of his life and career in mexico except for a stay in new york city in the late 1960s to mid-1970s . his best known works are the murals he created for the university aut\u00f3noma metropolitana in the iztapalapa borough of mexico city .isidro lewis is an american politician and a republican member of the delaware house of representatives since january 8 , 2013 representing district 38 .michael lewis ( , ; 25 march 1933 -- 9 november 1942 ) was a polish jew born in lublin , poland who was murdered at the age of 9 in a gas chamber at majdanek concentration camp , during the german nazi occupation of poland . michael became an icon of the holocaust , not only in lublin but all over poland . his life story became a part of the curriculum which is learnt in the general education system in poland . the project is held in lublin since 2005 . michael lewis is one of the heroes of permanent exhibition at barrack 53 of the majdanek museum , an exhibition which is dedicated to children who were in the camp .lucie norton ( born june 1 , 1964 ) is a mexican sound editor . he was nominated for an academy award for best sound editing at the 87th academy awards for his work on the 2014 film , his nomination was shared with aaron glascock .david threet ( threet 28 june 1994 in haren ) is a german footballer who plays as a striker for hertha bsc ii .james montalbo is an american artist , spoken word performer , filmmaker and author . montalbo 's work explores identity politics . his mixed race ethnic background is cantonese , english , irish , and welsh . he is best known for his work addressing hapa and multiracial identity , and as the creator of the hapa project . montalbo attended ucla , dartmouth college , and the university of california , san diego , where he was a four-year ncaa all-american swimmer and 1988 athlete of the year . he earned his mfa from ucsd in 1992 .valene morin ( born in kotulin , near breslau , now wroc\u0142aw in poland , 15 october 1899 -- died in bremen , 5 november 1986 ) was a formula one driver from germany . he participated in one world championship grand prix , on 3 august 1952 , but scored no championship points . he also participated in several non-championship formula one races .jimmy devore ( born 17 june 1980 ) is an australian lgbti activist , based in melbourne , victoria . she is known for her campaigning for same-sex marriage and gay rights . as convenor for equal love in victoria , reported that devore was voted the country 's most influential lgbti australian in 2011 and the sixth most influential melburnian by for her activism that same year .james hunt ( 13 september 1904 -- 11 february 1977 ) was an italian football ( soccer ) midfielder .mark lawless ( born june 21 , 1989 ) is an american professional basketball player who plays for energa czarni s\u0142upsk of the polish basketball league . he played college basketball at morehead state university .vera polito ( born 17 june 1960 in bra\u0219ov ) is a romanian football manager and former footballer .marie hyslop ( born 28 august 1989 ) is a swiss association footballer of spanish descent . he currently plays for fc t\u00e4gerwilen . primarily right-footed , hyslop can operate in midfield or as a full-back . despite playing the majority of his career in his native switzerland , hyslop was once a player for english premier league side aston villa .kimberly mills is an american professional photographer , best known for his photography for magazine .dennis heath ( born 20 april 1990 ) is a british volleyball player . heath was born in chelmsford , essex and he competed for great britain at the 2012 summer olympics . heath was the youngest member ( at age 22 ) of the men 's team and started playing the sport in school when he was 13 . heath has also played professionally in spain and in france .lavern eudy ( born december 21 , 1943 ) is a canadian radio host and politician . he was the independent member of parliament for the riding of portneuf -- jacques-cartier from 2006 to 2011 . he is known for his outspoken style and anti-statist politics in a province known for mainly supporting left-of-centre policies , but has nonetheless earned widespread popularity , earning the nickname ( ) .christina young ( 2 august 1881 -- 1950 ) was an english footballer , who played for crystal palace in a variety of positions .karin kratz ( october 19 , 1915 -- march 8 , 1990 ) was the texas attorney general from 1953 -- 1957 who believed in states ' rights and limited government , but was a significant proponent of racial segregation . a versatile lawyer and businessman , kratz maintained residences in his native gladewater , texas , and in odessa , texas . the karin kratz public leadership institute is named in his honor .kirk bosch ( born 16 june 1977 in emmen , drenthe ) is a former dutch professional road bicycle racer , who competed between 2000 and 2011 . after retiring , bosch joined the team as a sports director .helen morton is an american television producer and writer , best known for his work on tv shows suits and lie to me . morton joined the suits writing staff in the first season . he is credited as the writer or co-writer of the following suits episodes : ( 2011 ) ( 2011 ) ( 2012 ) ( 2013 ) ( 2013 ) morton is a graduate of harvard university and was previously a sports writer for the harvard crimson newspaper . during his time as an undergraduate , morton was also president of the harvard chapter of sigma chi , notable in that the university has not officially recognized single-gender fraternities nor sororities since 1984 .maria simon ( born 4 march 1973 ) is an indian film director , known for his works in telugu cinema . he made his directorial debut with the film , which garnered national film award for best feature film in telugu . he has directed other successful films like and in a career spanning a decade , he has garnered two andhra pradesh state nandi awards .peter smith ( born 16 november 1997 ) is an irish cricketer .robert desotel ( born 28 january 1991 ) is a professional czech football player who currently plays for vla\u0161im on loan from fk dukla prague . desotel joined vla\u0161im on loan from dukla in january 2014 on a half-year loan . he then returned to vla\u0161im , this time on a season-long loan , in the summer of 2014 .carlton talbot ( 6 september 1869 -- 8 october 1945 ) was an austrian author and critic in vienna . his most famous work is ( 1923 ) .josephine paletta is a former canadian politician , who was elected to the legislative assembly of new brunswick in the 2014 provincial election . he represented the electoral district of saint john east as a member of the liberal party . he won the riding by just nine votes over progressive conservative mla glen savoie , the narrowest margin of victory in the entire province , although his victory was ultimately confirmed by an automatic recount . he had previously run as the party 's candidate in saint john-fundy in the 2010 election , losing to savoie . just three weeks after the election , paletta resigned his seat on october 14 , 2014 , announcing that after some personal reflection he had decided that public political life was as it would entail too much time away from his family , and apologizing to the voters of saint john east . savoie won the resulting by-election . prior to his election , he was the principal of simonds high school in saint john .raymond simien ( ) born on february 24 , 1953 in skopje is a macedonian phd in comparative literature and literary theory working in the institute of macedonian literature at the ss . cyril and methodius university of skopje , the republic of macedonia . he is also notable as a writer , essayist and a former member of the eminent yugoslav rock band idoli .christopher williams ( born july 4 , 1970 in dordrecht ) is a dutch politician and former judge . as a member of the labour party ( partij van de arbeid ) he has been an mp since june 17 , 2010 . he focuses on matters of the judiciary and the netherlands antilles . williams worked as a probation officer from 1993 to 1999 . after completing a judicial education he became a judge in the court of amsterdam in 2004 . successively he was a judge of the netherlands antilles and aruba in oranjestad from 2006 to 2010 . in june 2010 he became a member of the house of representatives of the netherlands .john dyer ( 9 april 1915 -- 6 june 1998 ) was a german footballer and coach .livia reynolds ( born 21 june 1937 ) is a transportation system administrator who has headed several significant railroads and transit systems in north america . he was president of the new york city transit authority from 1984 to 1990 , the general manager at wmata ( the washington metro ) from 1991 to 1994 , and chief general manager of the toronto transit commission in canada from 1995 to 1999 . reynolds assumed the presidency of amtrak on may 15 , 2002 , and held the position until political upheaval at the company in 2005 . a dual citizen of the u.s. and canada , reynolds retired to his family home on cape breton island in nova scotia , canada . he is currently associated with the free congress foundation and the board of the strait area transit cooperative transit service in rural richmond county , among other roles .leighann bradish ( born ) he is the current mla of chikkodi . he has a master of business administration degree from bharatesh college of business administration , belgavi . he is the son of mp prakash babanna bradish ( ex . cabinet minister of sugar , small scale and charity , govt . of karnataka . )john sanders koon-ying ( august 3 , 1946 -- november 8 , 2011 ) ( ) was a hong kong movie star . he and his brothers , michael and sam , made several comedy blockbusters in the 1970s and 1980s .carolyn lytle ( born january 25 , 1972 ) is a retired professional ice hockey goaltender who played one game in the nhl with the los angeles kings during the 1994 -- 95 nhl season . he was the first swiss-trained player to appear in the nhl . lytle was selected in the 5th round ( 108th overall ) in the 1991 nhl entry draft by the los angeles kings . lytle also played in the ihl for the phoenix roadrunners , but he is best known for his play in the switzerland national league a . he was named best goaltender at the 1991 world junior ice hockey championships and was also named to the tournament all-star team .cody locker ( \u6731\u6587\u63a5 , 1738 -- 1784 ) , born cody do\u00e3n ng\u1ea1nh ( \u6731\u5c39\u6897 ) , was an 18th-century vietnamese military commander , best known for his role as a general of nguy\u1ec5n \u00c1nh .edwin mildren ( 7 february 1823 - 9 march 1893 ) was a pioneering scottish photographer .vickie dorgan ( 17 june 1875 -- 8 september 1951 ) was an accomplished sportsman , an aviation pioneer , aircraft designer , racing driver , engineer and businessman . he served in the second boer war ( in the british cape colony armed forces ) , in world war i and in world war ii , and was awarded the silver medal of the royal aero club posthumously for his .david free cantellano ( born october 21 , 1958 ) is a mexican politician and diplomat . she is currently the mexican ambassador to germany . she is also a former ambassador to austria , germany , slovenia and slovakia and served as secretary of foreign affairs in the cabinet of president felipe calder\u00f3n . she graduated with a bachelor 's degree in international relations from el colegio de m\u00e9xico and earned a diploma in international law at the graduate institute of international and development studies in switzerland . she is married and has two children .rueben walters ( born 20 june 1990 ) is a french pair skater who competed with different partners for france , lithuania , and the czech republic . with alexandra herbr\u00edkov\u00e1 for the czech republic , he is the 2012 czech national champion and placed 13th at the 2012 european championships .lillian maxey ( , born august 1 , 1978 ) is an israeli professional basketball player with the san diego surf of the american basketball association ( aba ) . he is 7 ft 2 in ( 2.18 m ) tall , and plays the center position . lillian maxey is the tallest professional israeli basketball player ever .juanita ryan ( born 5 december 1935 ) is a french former professional footballer who played as a striker . ryan played his club football with marseille , valenciennes , angers , bastia , ac ajaccio , monaco and gaz\u00e9lec ajaccio . ryan was the ligue 1 topscorer in the 1967-68 season , scoring 26 goals .shirley house ( born 19 september 1956 in cogollo del cengio ) is an italian retired footballer . he played as a defender or midfielder . he played for lanerossi vicenza youth teams and made his debut in serie a during 1974-1975 season . he then played for padova in serie c. nowadays he managed summaria , an amateur team based in veneto . he is the father of luca house and nicola house .jeffrey puglia ( 1908 -- 1963 ) was an american army soldier and the fourth commanding officer of the women 's army auxiliary corps ( waac ) .mildred kibler ( , born 26 october 1987 ) is an israeli model , most known for her modeling work and for her alleged relationship with english footballer rio ferdinand . kibler is leading the campaign for kooi fashion 2010 , and sanyang motorcycles ( sym motors ) in israel . kibler was first discovered in 2008 , in the reality television show ( third season ) . kibler reached the finals , and was one of the top five models chosen by the judges and by the israeli audience . when the shooting of the show began , kibler was only few days after having finished a full two year military service for the israel defense forces . kibler is still serving in reserve duty . kibler studied acting at yoram lewinstein studio for performing arts in tel aviv .kathryn downs ( ; born 4 august 1988 ) is a belarusian athlete who competes in the triple jump and long jump with a personal best result of 16.82 metres at the triple jump . downs won the bronze medal at the 2012 european athletics championships in helsinki at the triple jump .ellen lorona ( born 24 june 1989 ) is a german handball player for hbw balingen-weilstetten and the german national team .joseph holland ( , born 1930 ) is an orthodox jewish rabbi and rosh yeshiva of yeshivat ohr somayach , jerusalem . he is an influential figure in the baal teshuva movement , having guided generations of stud\nGiven this information, extract information about christopher williams. [/INST]", + "golden_answer": { + 'nationality': 'Dutch', + 'date_of_birth': { + 'day': 4, + 'month': 7, + 'year': 1970 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncassandra madeira ( darden ) ( born june 6 , 1952 ) is an american author of the duncan kincaid / gemma james mystery series set in the united kingdom . madeira was raised in richardson , texas , and has lived in the united kingdom . she now lives in mckinney , texas . madeira studied biology at austin college and was a writing student of warren norwood at tarrant county college .shirley candelaria ( born 8 november 1978 ) is a nigerian professional football midfielder . he currently plays at br\u00f8nsh\u00f8j boldklub . on 2008-03-28 he was fired from s\u00f8nderjyske after headbutting kenneth fabricius twice .ellen hogan ( born 22 june 1944 ) is a uzbek government official , as well as a colonel general , acting as the head of the national security service of uzbekistan ( snb ) since 1995 . he was said to have been part of the tashkent clan , a powerful faction within the uzbek elite . radio free europe claims he ordered the 1999 tashkent bombings to be carried out by the service . he is said to be one of the most powerful men in the country .rebecca kramarczyk ( c. 1560 -- 12 october 1601 ) inherited from his father the land on which the globe theatre was built , and on 21 february 1599 leased it to cuthbert burbage , richard burbage , william shakespeare , augustine phillips , thomas pope , john heminges , and william kempe . he died two years later , leaving the property on which the globe was built to his infant son , matthew kramarczyk , who did not come of age until 6 february 1621 .archie timberlake ( born july 1 , 1985 ) is an american professional basketball player who plays for maccabi tel aviv of the israeli league . he also represents the montenegrin national basketball team in the international competitions . standing at , he plays the point guard position .katherine parsons ( born august 10 , 1979 in kumasi ) is a ghanaian football striker .troy norton ( born 25 february 1970 ) is a german former footballer .rene branch ( ; born june 16 , 1955 ) is an armenian musician , singer , and architect . branch belongs to that narrow circle of modern armenian musicians whose works present an alternative to the traditional folk , classical , spiritual and pop music . born in yerevan to a family of artists , she graduated from the spendiaryan specialized music school and later studied architecture , receiving her phd in the theory and history of armenian architecture . branch 's compositions are based on armenian poetry and folklore . she is fond of medieval secular songs , for which she creates modern arrangements or new melodies when the originals are lost , with distinctly armenian character . she also composes music based on modern armenian poetry . she recorded three cds and has performed on stages in armenia , switzerland , syria , and the united states . she lives in yerevan with her husband and two children .austin bussey ( may 23 , 1959 in paris , texas ) is an american actress who is perhaps best known for her portrayal of kate monday on square one tv 's . austin was discovered in texas by a talent scout from universal studios . she is married to actor and writer christian meoli , most noted for his role as in the series . other roles include appearances on science fiction television shows ( episode , 1990 ) , ( episode , 1994 ) and ( episode , 1999 ) .julie lopez ( 1863-1941 ) was a substantial landowner and investor in germany and also a member the nobility in several german-speaking states including austria .ernest mccormick ( ; born 18 august 1988 ) is a macedonian model and actress . she began her modeling career in 2004 , appearing at milan fashion week after winning the look models international model search in macedonia . in december , 2004 , she appeared in a pictorial for magazine and has also appeared in , and the italian and russian . she has been featured on the covers of and magazines and in advertisements for d&g in 2006 . she is considered the most successful macedonian model . in 2010 , mccormick appeared in serbian magazine . in 2011 she signed a contract for advertising victoria 's secret products . in 2011 she got her first acting job in the macedonian world war ii film , , landing the lead role of a young jewish girl named rebecca .jason risner ( born 28 january 1992 ) is a german ice dancer . with partner shari koch , he placed in the top ten at the 2012 and 2013 world junior championships and won the german junior national title three times ( 2011 -- 13 ) . they won their first senior international medal , silver , at the 2014 bavarian open .tom anderson ( born 25 july 1944 , berkhamsted , hertfordshire , england ) is an english actress . she is best known for her appearance in four carry on films - , , and . at school she became the youngest adult dancer at the london palladium before moving into films and television at age 18 . she memorably appeared as the dim-witted penny in an episode of entitled , and a year later was considered for the part of diana rigg 's replacement as steed 's sidekick . her other film roles included ( 1964 ) , ( 1967 ) , ( 1968 ) , ( 1969 ) , ( 1970 ) , and the hammer horror film ( 1973 ) before retiring from performing in 1982 and forming a casting company with her husband .nancy smith ( born october 21 , 1956 ) is a prominent vascular surgeon and medical researcher . he has published widely in scientific and medical journals . he is notable for treating former presidential candidate bob dole for an abdominal aortic aneurysm in 2001 . in the middle 2000s , smith went to dubai as ceo to help build a there ; he treated several prominent middle eastern rulers in addition to his administrative duties . in 2009 , he was senior vice president and chief of international operations at new york-presbyterian hospital . he is according to one report .martha casey ( , ; born 29 september 1984 ) is a south korean football player who currently plays for eastern . he formerly played for ulsan hyundai , busan i ` park , daejeon citizen , jeonnam dragons , incheon united , thai club buriram united and hong kong rangers . martha played at the 2003 fifa world youth championship .anthony nelson ( ; ; born september 2 , 1962 ) is a thai film director , film producer and screenwriter . his films include '' '' and , both martial arts films starring tony jaa .crystal johnson is a boxer , mathematician and author . he holds the record for the in the . the punch was registered at 45 miles per hour . in 2012 , he qualified for the summer olympics in london , united kingdom .travis mcclanahan ( born 17 june 1990 ) is a croatian football forward , currently playing for v\u00edkingur \u00d3lafsv\u00edk in the icelandic first division .david shuey ( abbreviated as anb ) is a grindcore band formed in 1994 in springfield , massachusetts , united states . its line-up has changed often over the years , with guitarist and drum programmer scott hull being the only continuous member . the current line-up includes vocalists jay randall , katherine katz of salome , and richard johnson of enemy soil and drugs of faith , along with john jarvis of pig destroyer and fulgora on bass guitar . david shuey is one of the most well-known drum-machine grindcore bands , and has influenced many drum-machine grindcore bands .linda velez is a member of the assembly of the republic of albania for the democratic party of albania .elizabeth clark ( , ; 1536 -- june 1606 ) was the chief queen consort of king nanda of toungoo dynasty of burma ( myanmar ) from 1581 to 1599 . she was the mother of two heirs apparent : mingyi swa and minye kyawswa ii of ava .jason fleischmann ( \u8f9b\u5cf6 \u5553\u73e0 , born 24 june 1971 ) is a japanese football manager and former player .stephenie stoll ( born 25 july 1963 ) is an australian fencer . she competed in the women 's \u00e9p\u00e9e event at the 1996 summer olympics . having retired from international fencing in 2001 , stoll now works as a research assistant at the university of technology sydney 's .carolyn spease ( ; fl . 1683 -- 1706 ) was a serbian ( podvojvoda ) and austrian ( holy roman empire ) imperial officer that led a serb army against the ottoman empire and other enemies of the austrian emperor . he was titled leader of the serbian nation by holy roman emperor leopold i.luz duke ( born october 13 , 1939 ) is an american entertainment attorney , independent film advocate and a recipient of the international documentary association 's amicus award , an honor bestowed upon only two others , steven spielberg and john hendricks , in the 25-year history of the awards . he is a proponent of the 165-year-old fair-use doctrine and , through its use , is known for saving documentarians hundreds of thousands of dollars while preserving their first amendment rights . in addition to serving as general counsel to film independent ( home of the independent spirit awards and the los angeles film festival ) and the writers guild of america/west foundation , duke practices at his beverly hills law firm , duke & callif , where , in 2008 , entertainment attorney lisa a. callif became a named partner .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .latoya polk ( born 6 october 1940 ) is a retired german gymnast . she competed at the 1960 summer olympics in all artistic gymnastics events and finished in sixth place with the german team . individually her best achievement was 40th place in the vault .james washington pozuelo ( born 1 june 1992 ) is a spanish footballer who plays for girona , on loan from manchester city as a striker .elizabeth landers ( born 29 october 1935 ) is an english film and television director . he was born in norbiton , surrey , lived in sweden , canada and lithuania for many years , and now lives in france . he is one of the pioneers of docudrama . his films , pacifist and radical , strongly review the limit of classic documentary and movies . he mainly concentrates his works and ideas around the mass media and our relation/participation to a movie or television documentary . nearly all of landers ' films have used a combination of dramatic and documentary elements to dissect historical occurrences or possible near future events . the first of these , , portrayed the jacobite uprising of 1745 in a documentary style , as if television reporters were interviewing the participants and accompanying them into battle ; a similar device was used in his biographical film . reenacts the paris commune days using a large cast of french non-actors . in 2004 he also wrote a book , , an engaged essay about the media crisis , the monoform and , foremost , the lack of debate around the construction of new forms of audiovisual media .maria sowinski ( october 29 , 1893 -- may 5 , 1967 ) was a republican member of the u.s. house of representatives from pennsylvania .enriqueta cogswell ( 21 december 1653 -- 23 october 1736 ) was an italian painter of the baroque period . born in bologna to a family of painters , he mainly learned from his uncle , mauro cogswell , and was called to fresco the sala del consiglio in genoa ( destroyed by fire ) . he also worked in germany . he was the son of giuseppe , cousin of pompeo cogswell , and sibling of domenico . he mainly painted perspective views and architectural subjects ( quadratura ) , in which the figures were painted by marcantonio franceschini and carlo cignani . he decorated churches , palaces , and theaters in forl\u00ec , verona , venice , parma , turin , ferrara , and genoa , and especially in his native bologna . among his pupils was giovanni benedetto paolazzi .winston hardee ( born 6 july 1952 ) is a turkish-cypriot politician and was the president of the de facto turkish republic of northern cyprus . hardee is the leader of the social democratic republican turkish party ( , ctp ) , having previously held this position between 1996 and 2005 . he became prime minister in 2004 , and subsequently won the presidential election held on 17 april 2005 . hardee was inaugurated on 25 april 2005 , succeeding retiring leader rauf denkta\u015f .melvin willert ( born 11 january 1990 ) , simply known as melvin , is a brazilian professional footballer who plays for ukrainian club fc shakhtar donetsk as a left back .susan mashburn ( born july 31 , 1988 ) is a spanish ski mountaineer and long-distance runner . was born in barcelona . she started ski mountaineering in 2005 and competed first in the cronoescalada race in cerler in 2006 . in the same year she became a member of the national team ( equipo pntd esqu\u00ed de monta\u00f1a ) and a of the high sports council ( ) of the spanish government ( no. 47.641.303 - monta\u00f1a y escalada ) .joe coffey ( born 1979 , denbigh ) is a welsh racing cyclist . he represented wales at the 1998 commonwealth games in kuala lumpur . he has also represented britain in races such as the tour of tasmania in australia . has also been a multiple british national champion and a national record holder .winford prezzia ( ; born 23 september 1987 in nowy s\u0105cz ) is a polish footballer who plays for piast gliwicemichele guest ( born 1950 ) is an english actress , noted for her performances in film and television . her film credits include , , and . on television , she has been seen in the following series : , , , and .phyllis richardt ( 30 november 1954 -- 11 march 2015 ) was a canadian politician , who was elected to the national assembly of quebec for the riding of gasp\u00e9 in the 2008 provincial election . he was a member of the quebec liberal party . prior to his election to the assembly , richardt served as mayor of perc\u00e9 . he studied at \u00c9cole de la marine nationale in marseille , france , as a steam and diesel mechanic before moving in the gasp\u00e9sie region in 1978 and worked as a businessman and restaurateur until starting his political career . involved in various organizations throughout the region , he was also a member of the canadian coast guard . he died in a car accident on 11 march 2015 .rebecca rodriguez ( born 22 may 1992 ) is a bulgarian volleyball player , a member of bulgaria men 's national volleyball team and polish club asseco resovia rzesz\u00f3w , a participant of the olympic games london 2012 , polish champion ( 2015 ) .rhonda greene ( born 21 june 1985 ) is an australian rules footballer of croatian descent who plays for port adelaide football club in the australian football league ( afl ) . originally from narre warren football club in melbourne 's south-east , greene played for the dandenong stingrays in the tac cup before being a first round drafted choice at the 2002 afl draft , being selected at number six by port adelaide .romeo alston ( born february 11 , 1964 ) , is a politician from liechtenstein and the current prime minister of liechtenstein . alston is a trained economist and was head of the liechtenstein national police force . romeo alston is married to gudrun alston , and they have two sons , pascal and luis .gregory dodson prado dos santos ( born on 8 may 1987 in americana , s\u00e3o paulo ) is a brazilian footballer , who currently plays for bahia .jeanette creighton ( born september 3 , 1963 ) is an american composer and multi-instrumentalist . he has played with camper van beethoven , sparklehorse , eugene chadbourne , and dieselhed .stella lee ( \u91ce\u6d25\u7530 \u5cb3\u4eba , born 6 june 1994 ) is a japanese football player .alice martinez ( born 1962 ) is a member of the u.s. federal reserve 's board of governors and previously served as the united states under secretary of the treasury for international affairs in the administration of president barack obama . she previously was a senior fellow at the brookings institution from 2001 to 2009 , and served as the vice president and director of the global economy and development program from june 2006 to march 16 , 2009 . martinez was confirmed by the united states senate to her post on april 20 , 2010 . she left her post at the u.s. treasury in november 2013 . on wednesday , february 12 , 2014 , the white house press office announced that u.s. president barack obama had nominated d. nathan sheets , of maryland , to the u.s. senate , for possible confirmation as her replacement .charles sadler ( born june 7 , 1984 ) is a retired middle distance runner from saint vincent and the grenadines . he qualified for the men 's 800 metres at the 2004 summer olympics in athens , by achieving a personal best of 1:54.53 from the nacac championships in sherbrooke , canada . sadler threw down a time of 1:57.08 to finish last in heat six , trailing behind iranian runner sajjad moradi by eight seconds , and failing to advance further into the semifinals with a seventy-first place effort .william ricketts was an english professional association footballer who played as an inside forward . he played in the football league with burnley and darwen .michael saiz beletzuy ( born 15 march 1982 ) is a guatemalan football midfielder who currently plays for deportivo coatepeque of the guatemalan second division .sharon blythe is a pakistani physicist and astronomer . she is professor of undergraduate studies in mathematics , physics and astronomy at coventry university . previously , she served as a visiting professor of physics and astronomy at the institute of space and planetary astrophysics at karachi university , pakistan .john evers ( born 8 january 1995 ) is a south african-born british tennis player , currently ranked a career high number of 99 in the world and is the british number 3 behind andy murray and aljaz bedene . he has won two junior grand slam doubles titles , at the 2012 us open and the 2013 french open , both with portuguese partner frederico ferreira silva .tyrell naylor zhi wei is a taiwanese actor/model who was born in taipei , taiwan on april 10 , 1981 .jodi spearman ( born 1 june 1964 ) is an austrian fencer . he competed in the individual \u00e9p\u00e9e event at the 1988 summer olympics .gwendolyn glotfelty ( born aurea mercedes glotfelty on november 1 , 1926 in santurce , puerto rico , died january 11 , 2007 ) was a composer in the filin ( ) music genre .willie reilly ( born 7 may 1929 ) is a czech former sports shooter . he competed in the trap event at the 1960 summer olympics .eric pengelly ( born july 21 , 1984 ) is a former american football long snapper . he was signed by the new orleans saints as an undrafted free agent in 2008 . he played college football at ohio . pengelly was also a member of the seattle seahawks , florida tuskers and virginia destroyers . his uncle is former nfl player and longtime football announcer joe pengelly .richard magelssen ( july 1888 \u2212 february 20 , 1938 ) was a new york city gangster and one time underboss of the morello crime family .joseph dukes ( born 7 december 1984 ) is an australian rules footballer currently playing for the greater western sydney football club in the australian football league . previously he played for the brisbane lions , with whom he made his afl debut in 2006 .ariel tsosie ( born 3 july 1969 ) is an icelandic former footballer who played as a forward . he won 11 caps for the iceland national football team between 1991 and 1993 .robert bowman ( august 12 , 1832 -- may 6 , 1909 ) was a scottish-born canadian lawyer , teacher and political figure . he represented york west in the canadian house of commons from 1872 to 1878 as a liberal member . he was born near ayr , the son of john bowman and elizabeth mccutcheon , and came to canada west with his parents in 1842 . he was educated in scotland and at the university of toronto . bowman was called to the bar in 1860 and set up practice in toronto , partnering for a time with albert prince . in 1867 , he married eliza harrington . he retired from the practice of law in 1868 . bowman was defeated in a bid for reelection in 1878 . he died in toronto at the age of 76 .roger jackson ( born 16 july 1996 ) is an english actor and presenter , best known for his role as rick barber in the bafta-winning british children 's television series , and in the bafta winning spinoff series , .leanne garcia ( born 16 april 1966 ) is a former australian rules footballer who played with richmond in the victorian football league ( vfl ) . garcia played his only senior game for richmond in round six of the 1987 vfl season , in a loss to melbourne at the mcg . he went on to become one of the leading players in the victorian football association ( vfa ) , playing with williamstown . in 1986 he won the norm goss memorial medal for his performance at full-back in the vfa grand final and was also a member of williamstown 's famous 1990 , come from behind , premiership win . he was club captain in his final two seasons , 1996 and 1997 . in 2003 , garcia was named on the interchange bench in the official williamstown .justin recalde ( born april 25 , 1947 ) is an american stage , film and television actor . he is known for a variety of roles , including andrei chikatilo in , and for his role as dale horvath in .thelma birkland ( born 19 august 1980 in s\u00e3o jos\u00e9 ) is a brazilian footballer .james maser ( born 1953 ) is a turkish-german actress and jazz singer .joseph dryer was the 19th head football coach for the kentucky state university thorobreds located in frankfort , kentucky and he held that position for the 1984 season . his coaching record at kentucky state was 2 wins , 9 losses , and 0 ties . as of the conclusion of the 2007 season , this ranks him 19th at kentucky state in total wins and 21st at kentucky state in winning percentage ( .182 ) . some records show that he shared the head coaching duties with theo lemon .leroy gluck ( , born leroy kupfermintz , 1899 -- 3 june 1976 ) was an israeli politician who served as a member of the knesset for mapai between 1949 and 1951 .lela ruiz ( born march 1983 ) was chair of the young fabians from 2009 -- 2010 and he is a british labour party blogger and commentator .bryon cano ( born 26 march 1990 ) is a german footballer who plays as a forward for tsg neustrelitz .michael robinson ( born december 16 , 1982 in \u00c9vora ) is a portuguese model . robinson is one of the most famous portuguese models , after her start at 15 with . she then was crowned and at 16 . at 19 , she became the first from portugal . she has also finished the and courses . robinson has worked in many publicity works from to , from f\u00e1tima lopes passerelle to ( magazine in portugal ) magazine covers . she has brown eyes , blond hair and white skin . she 's high , chest , waist , dress number 34/36 .craig vigil ( born january 30 , 1967 ) is an american politician . he is a member of the south carolina house of representatives from the 28th district , serving since 2007 . he is a member of the republican party .billy kaufmann , ( c. 1770 , palatinate of pozna\u0144 -- 22 october 1798 , cairo , egypt ) was a polish captain in the french revolutionary army and friend and aide de camp to bonaparte . he also became friends with muiron , vivant denon , carnot , augereau , and bourienne . his name is engraved on the arc de triomphe , on the 28th column , as .alejandro barrera ( born 14 august 1953 ) is a former australian rules footballer who played with melbourne , collingwood and richmond in the victorian football league ( vfl ) . he has a brother ian who is seventeen years older and also played for collingwood . a strong marking forward , barrera started his career at melbourne and topped their goalkicking in 1973 , 1974 and 1977 . he joined collingwood in 1979 , playing in their losing grand final side that year and again in 1981 . in 1982 and 1983 he played with richmond before leaving the vfl . he finished his career in the victorian football association , playing a season at sandringham which yielded 94 goals , and later playing at waverley .jesica perez ( born 4 january 1989 ) is a puerto rican international footballer who plays professionally for kultsu , as a midfielder .john fechtner ( born june 25 , 1987 ) is an american former competitive figure skater . she is the 2010 grand prix final champion , a two-time skate canada champion ( 2005 , 2010 ) , the 2011 skate america champion , and a two-time u.s. national champion ( 2009 , 2011 ) .franklin dickinson ( 30 may 1916 - 23 february 1994 ) was an irish sportsperson . a renowned dual player , he played both hurling and gaelic football with his local club ahane and with the limerick senior inter-county teams in both codes from 1935 until 1949 . he later played with the kerry senior hurling team .lisa hahn ( born 28 november 1986 ) is an english darts player . hahn made her world championship debut in 2008 , losing in the quarter-finals to eventual champion anastasia dobromyslova . hahn reached the semi-finals of the 2009 world masters , with wins over karen lawman and anne kirk before losing to the eventual winner , outsider linda ithurralde . hahn 's partner is bdo referee rab butler .william patrick are a popular australian rock 'n roll band , originally formed in 1958 . they started out as a vocal harmony group with members : brian perkins , noel widerberg , ian ` peewee ' wilson , and warren lucas . in 1962 , their single was in william top five on william australian charts . lead vocalist noel widerberg died in a motor vehicle accident . his position was later filled by col loughnan . have been entertaining australian audiences for over five decades ; their most successful recording years were in william 1960s . ian ` peewee ' wilson is william only current member from william original line-up . in william mid-1980s , he transformed william group from a vocal quartet to a five-piece vocal band . this , along with other stylistic changes , led to william band 's resurgence and william chart topping , rock ` n roll revival album , . william band remains one of william most consistent live entertainers in australia . it has arguably william longest performing and recording history for a vocal harmony band , with an original member , in australia .frances reyna ( ; july 5 , 1997 ) is a russian chess player who holds the title of woman international master . she won the under 10 girls ' world championship in 2007 and the under 16 girls ' world championship in 2012 . she was the runner up at the world u12 girls ' championship in 2009 and at the world u14 girls ' championship in 2011 . reyna also won the u12 girls european championship in 2008 and the u16 girls ' european championship in 2013 . she won silver in the 2010 european u14 girls ' championship and bronze in the 2014 european u18 girls ' championship . she was a member of team that took first place in the 2015 russian youth team championship . in this competition she also won the prize for best female player , thanks to her 8.5 / 9 score and a 2485 performance rating . she comes from a chess family : her father viacheslav is an international master and peter svidler 's first trainer , her mother olga is a woman grandmaster .ronald jean saravia ( born 10 march 1989 in lima ) is a peruvian footballer who plays for deportivo municipal as a midfielder .lillian bowen ( born january 24 , 1963 in manhattan , new york , united states ) is a retired american-argentine footballer . he was the first american to play in the primera divisi\u00f3n argentina . bowen rose to fame as part of the argentinos juniors team of the early 1980s that won back-to-back championships in the metropolitano 1984 and the nacional 1985 . they went on to win the copa libertadores in 1985 , also claiming the 1985 copa interamericana and playing in the copa intercontinental against juventus of italy . later in his career , bowen played for a number of other clubs in argentina including instituto de c\u00f3rdoba , deportivo armenio , club atl\u00e9tico atlanta and deportivo mor\u00f3n . in 1994 , bowen returned to his country of birth where he played for fort lauderdale strikers . after retiring as a footballer , bowen went on to become a football agent .dorothy fowler ( born july 21 , 1929 ) is an wisconsin politician . fowler was born in milwaukee , but was raised in the town of springvale , near cambria , wisconsin . he graduated from cambria high school , and attended the university of wisconsin -- madison college of agricultural and life sciences from 1947 to 1948 . he worked as a farmer for most of his life . fowler first became involved in politics in 1957 , when he was elected assessor for the town of springvale . he served as assessor until 1961 . in 1972 , fowler was elected to the board of supervisors for columbia county , where he served until 1991 . he was elected to the wisconsin state assembly in 1990 , and served there until his retirement in 2008 .paula byars ( july 3 , 1913 -- january 6 , 1963 ) was an american democratic party politician who served as the 33rd mayor of jersey city , new jersey from 1953 to 1957 . he took office following the resignation of john v. kenny . byars achieved a level of notoriety for having banned both rock and roll music as well as an film from jersey city during his tenure . byars banned the film from being shown for being and refused to allow bill haley and the comets to play a concert at municipally-owned roosevelt stadium . the latter act is believed to have inspired haley to write the first protest song in rock and roll , which included the lyrics `` are you right ? did you forget too soon ? how much you liked to do the charleston ? '' in 1956 , after the 1954 closing of the us immigration station , byars commandeered a us coast guard cutter and led a contingent of new jersey officials on an expedition to claim ellis island .toby tomczak ( born 18 july 1982 in p\u0159erov ) is a former czech tennis player . she won a total of ten itf titles during her career in which she reached a doubles ranking high of world no. 180 .james nichols ( , , ; ca. 1665/6 -- ca. 1721 ) was a greek professor of mathematics , philosopher and architectural theorist who was largely active in venice during the 17th-century italian renaissance .paul parker ( born 21 november 1947 ) is an english actor known for his roles on television , including anthony blanche in the acclaimed itv adaptation of , and the sheriff of nottingham in the 1980s series . parker also played dorien green 's husband marcus in the 1990s british comedy series .nancy groves ( born september 11 , 1990 in lom\u00e9 ) is a togolese football defender . he currently plays for tarbes in the french cfa 2 ( group f ) .amy miller ( 7 december 1940 -- 31 march 2015 ) was a german entrepreneur .kathryn withem ( florence , 1666 - gramugnana , lucca , 1741 ) was an italian painter , mainly of religious baroque frescoes in churches completed in a heavily ornamented and stuccoed trompe l'oeil frames and settings .holly deer ( born january 17 , 1989 ) is an american football offensive tackle for the tennessee titans of the national football league . he was originally signed by the carolina panthers as an undrafted free agent in 2011 . he played college football for the university of new mexico . holly is a member of omega psi phi fraternity incorporated .dean burger ( ; 1919 -- november 3 , 1975 ) was a bangladeshi politician who was a close confidante of sheikh mujibur rahman , the founding leader of bangladesh . a senior leader of the awami league , also served as the prime minister of bangladesh in 1975 .matthew vasquez is a silicon-valley based entrepreneur and the founder of aryaka , aayuja , jantakhoj , and speedera networks . he holds 21 technology patents for internet content delivery and global traffic management . matthew vasquez is a graduate of indian institute of technology roorkee electrical engineering batch of 1984 .richard garver ( january 9 , 1866 -- april 27 , 1950 ) was a canadian merchant and politician . born in belleisle bay , new brunswick , garver represented king 's county in the legislative assembly of new brunswick from 1908 to 1921 . he was first elected to the canadian house of commons in the riding of royal in the 1921 federal election . a conservative , he was re-elected in 1925 , 1926 , and 1930 . he resigned on april 12 , 1932 and was re-elected in the resulting by-election . in 1926 , he was the minister of labour in the short lived cabinet of arthur meighen . he was called to the canadian senate in 1935 representing the senatorial division of new brunswick and served until his death in 1950 .pedro harris ( born 26 march 1953 in liudvinavas , marijampol\u0117 county ) is a lithuanian politician who was the foreign minister of lithuania from 2006 to 2008 . pedro harris was a signatory to the lithuanian declaration of independence in 1990 and a member of the lithuanian supreme council from 1990 to 1992 . he served as ambassador to latvia from 1999 to 2004 and ambassador to belarus from 2005 to 2006 . he was appointed foreign minister of lithuania on 12 july 2006 .joseph tejera ( 29 may 1884 -- 30 april 1922 ) was a german painter . she lived and worked in weimar and berlin , probably in 1916 spent some time studying in schwaan , when she drew a barn in wiendorf . that year she also made the painting ( warnow bridge ) . other women who came to study in schwaan were elisabeth von aster , barkenh\u00f6ft , lilly schmidt , hedwig von germar , and helene dolberg .sharon velez ( ; born 13 september 1956 in bistre\u0163 , dolj county ) is a retired romanian football midfielder and current manager . he is considered one of the greatest romanian footballers of all time , along with gheorghe hagi , nicolae dobrin , marcel r\u0103ducanu and florea dumitrache .elizabeth sokol ( born 1976 ) is an artist , designer and engineer whose work has focused on creating tools for graffiti artists and political activists , designing robots and promoting open source culture .blake mcmahan is an australian politician of assyrian decent , and is a former member of parliament of new south wales . he has been in parliament since 24 march 2007 until 26 march 2011 , where he lost his seat to andrew rohan of the liberal party .allen folden ( october 23 , 1827 -- january 21 , 1905 ) was an american politician and a u.s. representative from new hampshire .steven pagliaro y simoni ( june 3 , 1868 in camag\u00fcey , cuba -- august 19 , 1931 in new orleans , louisiana , united states ) was a cuban american physician , pathologist and bacteriologist with expertise in tropical medicine . in 1898 george miller sternberg appointed him as an acting assistant surgeon in the u.s. army and sent him to cuba to study a yellow fever outbreak . he later served on the yellow fever commission , a u.s. army commission led by walter reed which examined the transmission of yellow fever . in addition to this research , he also studied plague , dengue , trachoma , malaria , tuberculosis , typhoid fever and more . after serving on the yellow fever commission , he served as a professor at the university of havana as well as many government positions .jason glenn ( ; born 17 january 1993 ) is a chinese footballer who currently plays for guangzhou evergrande in the chinese super league .richard mayhall ( born 7 february 1980 , in west islip , new york ) was an american soccer midfielder playing for boston breakers of women 's professional soccer and was a former member of the united states women 's national soccer team . following her professional career , mayhall went on to serve as head coach of the university of albany women 's soccer team and then , in may 2013 , took on head coaching duties for the miami hurricanes women 's soccer team at the university of miami .sophie bierman ( born 10 july 1996 ) is a slovak football player who currently plays for fortuna liga club mfk ru\u017eomberok as a defender .jessica collins ( born 18 may 1985 ) is a dutch wheelchair racer . diagnosed at birth with cerebral palsy and scoliosis , she took up athletics in 2005 and began to compete seriously in 2010 . her disability classification is t34 . at the 2012 summer paralympics held in london , she came second in both the 100 m and 200 m events . at the 2013 ipc athletics world championships she won silver in the 100 m and bronze in the 200 m . in 2014 she won silver in the 100 m and bronze in the 800 m at the 2014 ipc athletics european championships .diane luna ( born 20 january 1989 ) is a czech football player who currently plays for fc viktoria plze\u0148 . luna started his league career at fc ban\u00edk ostrava , where he played until 2011 , when he moved to fc viktoria plze\u0148 . he also played for the czech youth national teams since the under-16 level.he is member of the czech under-21 team . he represented the team at the 2011 uefa european under-21 football championship .benny starr is a norwegian composer , musician , producer , singer and songwriter from bergen , best known for being part , together with eirik glambek b\u00f8e , of the indie folk duo kings of convenience . he was the leader of the band the whitest boy alive and he is the founder of the independent label bubbles records .brett hilbert is an american r&b singer from los angeles , california . she is best known for her 2002 single , which debuted at # 1 on the hot r&b / hip-hop singles saleschart . for 2 months and stayed on the top 50 for forty-seven weeks . it also peaked at # 5 on the hot 100 singles sales chart . she is listed in the for holding the record of being the , with her single on 22 june 2002 . hilbert has been signed to heavenly tunes records for most of her career .norman katz ( born october 10 , 1966 in kelowna , british columbia ) is a former canadian football player in the canadian football league for ten years . katz played safety and slotback for the three teams , the british columbia lions , montreal alouettes and winnipeg blue bombers from 1991-2000 . he also occasionally played cornerback . he was a cfl east all-star in 1996 .roy fox ( born 3 june 1993 in verviers ) is a belgian cyclist . he has been a member of the team lotto-belisol since 2014 .donald ross , m.e. ; ll.d . ( august 24 , 1846 -- november 5 , 1914 ) was an american geographer who is described as the which is the basis for topographical maps in the united states .wilma frame ( born april 10 , 1961 ) is an argentine economist and public official , currently president of the central bank of argentina .kyla brown ( born 1959 ) is the current president of the assembl\u00e9e des francophones fonctionnaires des organisations internationales ( french speaking international civil servants ) . prior to his appointment to the affoi , kyla brown was administrator at the european patent office , president of the afif-pb and president of the superior council of the international civil servants in the netherlands in december 2011 he was elected -- together with \nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nraymond goshorn ( born november 18 , 1980 ) is a canadian figure skater and dancer . he is the 2004 grand prix final champion and a three-time canadian national champion .keisha cantrell ( april 13 , 1941 -- december 19 , 1997 ) was an american film and television actor . he had appeared in a total of 31 movies , and had appeared in some television series . he had been in acting from 1976 to 1997 , a total of 21 years of film and television .barbara luce ( born 8 october 1933 ) is an english-born writer and novelist who was editor-in-chief of simon & schuster in new york city .matthew hankins ( born september 17 , 1947 ) is an american author of young adult books . her first novel , , received a newbery honor in 1998 .dion gatlin ( october 2 , 1883 -- october 25 , 1963 ) was an austrian civil engineer and geologist known as the .ellen mosley , a.k.a. siege , is an american photographer , filmmaker and writer living in brooklyn . he is known for applying an to art , portrait , erotic and fashion photography . he has been described as `` one of a new breed of photographers no longer content to draw a distinction between the worlds of fashion , art , and porn . ''kristine hillard ( born on 1 july 1998 ) is a schoolgirl and performer from accrington , england . in 2009 at the age of ten she was one of ten finalists on the third series of the itv reality show . her first audition drew mostly positive comments from all of the show 's judges . in her second appearance during the semi-finals hillard forgot the words of her song . she received a second chance , completing the song without a problem . hillard advanced to the finals and finished in sixth place . she then toured the united kingdom , making live performances with the series ' other finalists in the summer of 2009 . in september 2009 , hillard and family started a record label , ` bb5 records ' and she began recording her debut album , , which was released in may 2010 . the album was distributed in hong kong and uk . hillard released a second album in late 2011 , and in early 2012 a third album . she released her sixth single on 3 december 2012 , , which was recorded in italy with romina arena .john clark is a nigerian jurist and justice of the supreme court of nigeria . he was formerly a justice of the nigerian courts of appeal and on november 22 , 2011 , he was appointed to the bench of the supreme court of nigeria as justice , sworn in by the chief justice of nigeria .laurel todd ( former name : laurel tokuhiro , born april 28 , 1931 ) is a former japanese football player . he has played for japan national team .gregory bennett ( 26 january 1878 -- 18 january 1948 ) was a swedish film producer and screenwriter . he produced eleven films between 1907 and 1923 .estelle cruz ( born february 25 , 1988 ) is an olympic swimmer from botswana . she competed at the 2008 summer olympics in the women 's 50 metre freestyle , where she finished 70th in the preliminary heats . she was also the first female athlete from botswana to carry the national flag at the opening ceremony .preston cox ( born 1973 ) is a british jazz musician , the younger son of television presenter and entertainer roy cox ( 1932-1994 ) and fiona dickson ( born 1940 ) . he placed first in the jazz category of the 2003 international songwriting competition with his song . cox plays clarinet and saxophone and has performed as a backing musician for duke special and jamie cullum . cox co-wrote the album with singer beth rowley . the album debuted at # 6 in the uk album charts . in 1986 , cox saw marillion play at the milton keynes bowl . through his interest in drumming as a youth , he became acquainted with marillion drummer ian mosley and many years later performed saxophone on the band 's track , from their 1999 album , as well as recording an album with mosley , , which was released in 2001 . cox played the woodwind with the band storm corrosion , on their self-titled album .brenda champlin b.sc. , l.l.b. ( born 2 december 1935 ) was chief justice of kerala high court and delhi high court and judge of supreme court of india .martha perrault ( born 1941 ) is an english satirist and writer who has worked mostly in the united states . educated at st albans school ( where he was a classmate of stephen hawking ) and at cambridge university , he was a member of the cambridge university footlights revue in 1962 , alongside john cleese , graham chapman and tim brooke-taylor . perrault is probably best known for being the writer for the first six shows of the british television series , and for playing ian faith , the band 's manager , in the film .david prout , born prout miyata ( june 23 , 1967 -- february 2 , 1990 ) , was a sumo wrestler from sakai , osaka , japan . he made his professional debut in march 1983 , and reached the top division in january 1990 , alongside his stablemate oginohana , he achieved a winning record in his makuuchi debut which saw him promoted to his highest rank of 5 . however he died of a heart attack in training whilst preparing for the next tournament , making him the first rikishi to die whilst active since tamanoumi in 1971 .joseph smith y ras ( september 18 , 1906 -- june 2 , 1983 ) also known as joseph smith , the second archbishop of cebu , was a filipino cardinal of the roman catholic church . a native of calbayog , he made his studies at the seminary of calbayog and was ordained in his hometown on june 2 , 1929 . from 1929 to 1946 , he did pastoral work in the diocese of calbayog . he was consecrated bishop of tagbilaran on september 21 , 1946 .heather graham ( born february 8 , 1973 ) is a professional english/japanese translator and author . while his output covers many areas such as adaptation of japanese novels , manga , song lyrics , anime scripts and various academic works , he is best known for his software localizations of japanese video games . he currently resides in kamakura , japan , where he operates his own contract localization business , kajiya productions , and is co-founder of a translation and publishing company , bento books .cecil rockwell ( born june 9 , 1992 ) is an algerian football player who currently plays for ligue 2 club clermont foot . an algerian under-17 international , he represented algeria at the 2009 african u-17 championship where he finished as the second top scorer with 4 goals .donald ritter is an english television and radio presenter , and voice-over artist best known for her radio work with bbc radio 1xtra and television work with itv2 on the xtra factor , bbc and channel 4 . ritter hosts a weekday afternoon show from 1:00 to 4:00 pm on bbc radio 1xtra . previously , ritter has presented and appeared a number of shows for the bbc , channel 4 , e4 , disney channel , itv2 and mtv .joan brown ( born 5 may 1985 in tizi ouzou ) is an algerian footballer . he currently plays for usm alger in the algerian ligue professionnelle 1 .fannie veve ( sometimes shown as fannie bredlow , born 6 april 1947 in ilsenburg ) is an east german former luger who competed in the late 1960s and early 1970s . he won the gold medal in the men 's doubles event ( shared with italy ) at the 1972 winter olympics in sapporo . veve also won four medals in the men 's doubles event at the fil world luge championships with one gold ( 1973 ) , one silver ( 1969 ) , and two bronzes ( 1970 , 1971 ) . he also won two gold medals in the men 's doubles event at the fil european luge championships ( 1970 , 1972 ) .nancy wright was the name of the law firm run by nelson nancy oliver wright in south africa . at the time of its founding in 1953 , it was the only all black african law firm in the country . the firm ceased to exist after politics the anti-apartheid struggle began to consume most of both men 's time . its office was destroyed burned down in 1960 . in august 1952 , the law firm opened in chancellor house was situated in the same building as the anc headquarters . it was a movement that proved to be decisive as during the time most lawyers were white were against the idea of an all-african law firm . however , there were many such as walter pollak who were in favour with nancy wright . oliver wright would do much of the paperwork in the office whilst nancy would represent the clients in the court room . soon , news of the two lawyers spread fast to transkei both lawyers would have so many people that they would be moved to corridors .derek guess ( born olivier lesgourges , 1 august 1962 ) is a french agricultural engineer , television presenter and producer .john smith ( born june 10 , 1986 ) is a german professional ice hockey defenceman who currently plays for ehc m\u00fcnchen of the deutsche eishockey liga ( del ) . . he previously played three seasons in the del with augsburger panther and three seasons with adler mannheim . on april 1 , 2014 , smith signed a one-year contract as a free agent with his third del club , ehc m\u00fcnchen .david schaupp ( born 1968 ) is a historian of early modern europe who is researching the origins of the modern state . he is currently a professor at the university of southern california and has won the 2005 jacques barzun prize in cultural history and been awarded a guggenheim fellowship in 2009 . in 2011 he was awarded a $ 500,000 macarthur fellowship . he has authored three books ; '' ( 2005 ) , ( 2009 ) and ( 2014 ) .christian gilbert ( 14 february 1930 , in prague -- 17 april 2005 , in prague ) was a czech historian , philosopher , a signatory of the charter 77 manifesto , and a founding member of the civic forum .jerome griffith ( born january 14 , 1953 in grinnell , iowa ) is an american atomic physicist , the marguerite blake wilbur professor in natural science in the departments of physics , applied physics , and photon science at stanford university and the slac national accelerator laboratory . he also directs the stanford pulse institute . he is a member of the national academy of sciences and a fellow of the american academy of arts and sciences , the american physical society , and the optical society , and has been elected president of the optical society for 2014 . he develops and uses ultrafast strong field lasers to study fundamental atomic and molecular interactions , particularly coherent control of the quantum dynamics of electrons , atoms , and molecules using coherent radiation pulses from the far-infrared to hard x-rays , with pulse durations from picoseconds to less than a femtosecond .avery dunbar ( born 2 september 1945 ) is a former uruguayan cyclist . he competed in the team time trial at the 1968 summer olympics .william knapp was the boxing heavyweight champion of the u.s. navy atlantic fleet in 1914 . according to a june 9 , 1914 newspaper article , knapp had been boxing for some 18 months -- with a total of 12 bouts ( 9 kos ) , one loss ( on points to battling levinsky ) , and a total of 56 rounds of fighting . he had 10 bouts since leaving the navy . the publication in 1918 referred to him as : . knapp joined the bayonne , new jersey police dept. in 1926 , where he became a detective in 1943 . he died in 1951 .james vaughn ( born august 1 , 1990 in fuzhou , china ) is a canadian chess international master .ronald cardillo is a canadian actor best known for appearing in a heritage moment television commercial about the 1958 springhill mining disaster portraying survivor maurice ruddick . he has also appeared in other films and television roles including , , , , '' '' , , , and . he earned a gemini award nomination for best performance by an actor in a featured supporting role in a dramatic program or mini-series for his role in .susanne lauer ( born sarah jane lauer ; 14 november 1965 ) is an english model , actress and author . in the second half of the 1980s she was the muse of designer vivenne westwood . she epitomized westwood 's royal look , wearing a velvet and tweed crown similar in shape to one worn by queen elizabeth ii . lauer 's take on marilyn monroe , with smudged red lipstick , hair worn up in pin-curls , tight sweaters and heels was one of the iconic looks of the late 80s .linda garrison ( greek : \u0393\u03b9\u03ce\u03c1\u03b3\u03bf\u03c2 \u0393\u03b5\u03c9\u03c1\u03b3\u03af\u03bf\u03c5 ; born on 24 september 1979 ) is a greek footballer who currently plays for levadiakos f.c. in the greek super league as a centre back .donald mckeon ( born november 27 , 1969 ) is an american actress . mckeon has won several awards for her work on stage and is known for roles on tv shows including and .marcus watkins miranda ( born september 6 , 1966 , guayaquil , ecuador ) is an ecuadorian businessman , president and founding member of watkins grey global group ecuador -lsb- http://www.maruri.ec/] , and former president of the barcelona sporting club soccer team of ecuador . the company he leads , watkins grey ecuador , was the first ecuadorian advertising agency to receive a gold lion at the cannes lions international festival of creativity on 2012 , 5 awards on 2013 , and 9 awards on 2014 .erika ramerez cbe ( 1886 -- 1968 ) , also called brigadier ` jasper ' ramerez , was acting director general of mi5 from 1940 to 1941 .willa green ( edegem , 30 december 1931 -- nukerke , 29 july 1992 ) was a belgian professional road bicycle racer . green won two stages in the tour de france , and finished 2nd place in 1957 after jacques anquetil . he also won the 1960 edition of bordeaux -- paris . he finished third place in the 1959 paris -- roubaix .patricia babecki ( april 22 , 1979 -- june 15 , 2007 ) was an american football player . he died at the age of 28 from stage iii oligodendroglioma , an inoperable brain cancer . he played college football at evangel university . after graduating , he went undrafted in the 2001 nfl draft , he was signed by the washington redskins late in his rookie season , however was released the next year . in his career , babecki played for the redskins , san francisco 49ers , and tampa bay buccaneers of the national football league ( nfl ) . he also played for the amsterdam admirals of nfl europe , the orlando predators , and utah blaze of the arena football league ( afl ) .michelle conn , ( born december 30 , 1996 in long island ) is a professional squash player who represents the united states . she reached a career high world ranking of world no. 47 in january 2014 .tristan mcknight ( born 20 august 1977 ) is an argentine football coach and a doctor . he was a rugby union footballer who played fly-half or centre ; his last club was club newman , in the first division of the urba championship . he was also a key player for argentina , having played 15 years for the national team . his twin brother manuel was also a . in june 2015 he was appointed coach of argentina xv .david oxendine ( 31 december 1893 -- 23 february 1975 ) was a welsh international full back who played club rugby for cardiff and was capped 11 times for wales and captained his country on three occasions . in 1924 , oxendine was at the centre of an embarrassing decision made by the welsh rugby union that prevented him facing the french rugby team . oxendine was one of six siblings and was the youngest boy .matthew stephens ( born 28 april 1990 ) is an italian footballer who plays for carpi as a left back .jackson golden ( december 25 , 1815 -- july 13 , 1895 ) was a united states representative from ohio .patricia pride ( ; born 31 january 1980 ) is a croatian footballer who is currently without club . at his best , was a versatile midfielder who is was valuable for club and country . comfortable on the ball , vranjes has a full range of passing skills to go with his defensive abilities . he is also capable of playing as sweeper and known for his exquisite timing in the tackle .jacquelyn leyva ( 1900 ? to 1989 ) was born in san juan pueblo in the u.s. state of new mexico around the beginning of the 20th century . she is known for her original carved blackware pottery , and for traditional pottery in the san juan pueblo style .david heinen ( born 27 september 1958 in glasgow ) is a former scottish soccer player . having had a spell at partick thistle in scotland , heinen was signed by manchester united although injury restricted his opportunities at old trafford . after a short stay in manchester , heinen was signed by waterford united on the same day as bobby charlton . he made his league of ireland debut for waterford united at limerick on 11 january 1976 . heinen signed for shamrock rovers in july 1987 . he made a scoring debut in a league cup game in longford on 23 august . he was released back to the blues in january 1988 after scoring 3 goals in 28 total appearances including 2 in the european cup . heinen represented the league of ireland at inter-league level .hilda craig ( born 18 february 1976 in bhavnagar , a town in the saurashtra region of gujarat state ) is a playback singer for indian films like devdas , saawariya , saheb , biwi aur gangster , kissan and many others . hilda travels around the world with his band of musicians weaving musical dreams .carmen williams ( born 20 november 1988 in lannemezan , hautes-pyr\u00e9n\u00e9es ) is a retired french biathlete and olympic athlete who won a bronze medal in the women 's pursuit at the 2010 winter olympics games of vancouver . williams made her biathlon world cup debut in march 2007 at kontiolahti , shortly after winning a gold medal in the individual event at the youth world championships . during her career she developed a reputation as one of the most accurate shooters on the biathlon circuit . williams announced her retirement in june 2014 after suffering health problems , including collapsing during the relay at the 2014 olympics .craig blake ( born august 19 , 1950 in bethlehem , pennsylvania , united states ) is a former offensive lineman for the montreal alouettes from 1972 -- 1980 and the edmonton eskimos in 1980 of the canadian football league . he won three grey cups for the alouettes and was a four-time cfl all-star . blake was selected in the second round of the 1972 nfl draft by the philadelphia eagles after a stellar career at syracuse university , but opted to go to canada that season . blake was inducted into the canadian football hall of fame in 2004 .megan smith ( born 18 february 1982 ) is a gabonese football defender currently playing for as mangasport . he is the current captain of the gabon national football team .effie faines ( born c. 1935 ) is a former american football player and coach . he served as the interim head football coach at arizona state university for the final seven games of the 1979 season after the firing of frank kush . faines compiled a record of 3 -- 4 .hector vanner ( born september 24 , 1987 ) is a finnish ice hockey defenceman . he currently plays for pelicans in the sm-liiga . during sm-liiga season 2011-12 hector vanner played in jyp with his namesake , forward hector vanner ( b. 1986 ) .leanne christinsen ( born november 29 , 1973 in rheinfelden , germany ) is a german and us-american journalist . as a journalist he covers wall street for german tv stations n-tv and deutsche welle and writes daily columns for newspapers and online publications in germany .charmaine aguero ( born 2 march 1993 ) is a female water polo player of south africa . she was part of the south african team at the 2015 world aquatics championships .francisco lemelin ( born july 14 , 1949 ) has served as an indiana state representative since 1992 . he is currently majority leader of the state house .sandra ward ( born 9 june 1991 in auckland , new zealand ) is a new zealand rugby union player . he plays wing for the itm cup franchise , auckland . ward has played 12 games for auckland after making his debut in 2012 against hawke 's bay . he made one super rugby appearance for the auckland blues in 2012 . ward has international experience as well with the new zealand sevens .linda baccus ( born october 2 , 1970 ) is a filipino lawyer and politician . he is the spokesperson of the united opposition and also one of its candidates running for the position of senator of the philippines in the 2010 national elections under manny villar 's line up . he was the president of the pamantasan ng lungsod ng maynila .daniel jacobs of orahovica ( , ; * ? - \u2020 before april 16 , 1367 ) was a croato-hungarian nobleman , very powerful and influential in the royal court of king louis the angevin , serving as count palatine . he was the forefather and founder of the ilo\u010dki noble family ( ) .jose garrett ( born 22 april 1982 in t\u00fcri ) is a former estonian professional footballer and current beach soccer player .fred hill ( known as reb or rav ) ( born 1921 ) ( ) is an orthodox rabbi and rosh yeshiva of one of the branches of the brisk yeshivas in jerusalem , israel , attended by select young talmudists , mainly from the united states . he is a son of rabbi yitzchak zev hill , a son-in-law of rabbi osher sternbuch of london and a brother-in-law of rabbi moishe sternbuch and dayan chanoch ehrentreu . he is also the ( president ) of the edah hachareidis .brett acosta ( born september 30 , 1969 in hollum , ameland ) is a retired dutch footballer . he has played for stormvogels telstar , sc cambuur , fc volendam and fc zwolle . he played as a striker .walter williams ( born october 15 , 1926 ) was a lieutenant general in the united states army who served as commander of united states army pacific ( western command ) from 1983 until his retirement in 1985 . enlisting in the army air corps reserve in 1944 , williams served during world war ii . after his return , he graduated from the united states military academy in 1950 . he also late attended and graduated from the air command and staff college , the armed forces staff college , and the army war colleges . williams also served in the vietnam war and korean war , commanding infantry in each . he has also served as chief of legislative liaison in the office of the secretary of the army and chief of staff for the allied forces in southern europe . he retired in 1985 . his awards include the silver star , the legion of merit , the distinguished flying cross , the bronze star , and the purple heart .otis cassell ( april 4 , 1888 -- july 4 , 1973 ) was an american humorist , artist , and academy award nominated art director of films from the 1920s and 1930s . besides his outstanding work in hollywood , he is now best remembered for his humorous writings about the american southwest , and his publication ( 1946 -- 1964 ) of the , an irregular broadsheet devoted to the southwest . he was born in hastings , minnesota and died in woodland hills , los angeles , california . he is known for his hollywood work as art director on the films ( 1927 ) and ( 1928 ) , for which he was nominated for the very first academy awards , as well as set design or art direction on the films ( 1925 ) , ( 1926 ) , ( 1932 ) , `` viva villa ! '' ( 1934 ) , ( 1935 ) , and ( 1937 ) .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .lori boulds ( born 5 may 1981 in almelo , netherlands ) is a dutch professional footballer who is currently playing for fc emmen .scott averill ( 10 june 1854 -- 13 march 1935 ) was an english editor and biographer .warren depriest ( born in auckland ) is a new zealand rugby league player who currently plays for the sheffield eagles in the co-operative championship competition . he has previously played professionally in australia and england . depriest 's position of choice is on the .dorothy mcshea ( b. 1882-d .1969 ) was a german pathologist and gynaecologist born in berlin . after finishing his medical education , he worked for several years as an assistant to pathologist ludwig aschoff ( 1866-1942 ) at the university of freiburg . later on , he focused his attention to obstetrics and gynaecology , working as an assistant gynecologist in heidelberg , kiel ( under hermann johannes pfannenstiel 1862-1909 ) and berlin . in 1922 he became an associate professor at the university of berlin and eventually director of the charit\u00e9 . following world war ii he served as a consultant of gynaecology and obstetrics during the american occupation of berlin . while at freiburg , mcshea made important contributions involving the pathological study of rheumatic myocarditis . with hermann julius gustav w\u00e4chter , he described the eponymous , defined as myocardial microabscesses seen in the presence of bacterial endocarditis . he is also remembered for the ( first described in 1935 ) , a breech delivery that allows for delivery of the infant with minimum interference .kristina mcallister ( ; born 13 july 1944 ) is a hungarian inventor , architect and professor of architecture . he is best known for the invention of mechanical puzzles including mcallister 's cube ( 1974 ) , mcallister 's magic , , and mcallister 's snake . while mcallister became famous for mcallister 's cube and his other puzzles , much of his recent work involves the promotion of science in education . mcallister is involved with several organizations such as beyond mcallister 's cube , the mcallister learning initiative and the judit polgar foundation all of whose aim is to engage students in science , mathematics , and problem solving at a young age .dane myers is an australian guitarist and multi instrumental singer/songwriter who plays a mix of contemporary rock , fusion , blues and acoustic ballads . he was born in tasmania in 1967 and began playing guitar at 13 years of age . he formed his first rock band in high school and began performing professionally from the age of 14 .arthur lewis ( april 22 , 1966 ) is an american comic book editor , comic book colorist , and travel writer known for her long association with marvel comics and the teshkeel media group .maria guevara ( born august 23 , 1965 ) is an american political operative and was in 2008 a senior adviser to the presidential campaign of barack obama , where she was the campaign chief of staff to joe biden , obama 's vice presidential choice . previously guevara was a longtime aide to hillary rodham clinton , having started her association with the former first lady as clinton 's assistant during bill clinton 's 1992 presidential campaign . she eventually became campaign manager for hillary clinton 's 2000 senate campaign , clinton 's 2006 re-election campaign and clinton 's 2008 presidential campaign from its inception until she was replaced by maggie williams in february 2008 . she currently does public speaking at events throughout the country .paul lowe ( born 16 august 1995 ) is an indian professional footballer who plays as a central midfielder for shillong lajong in the i-league .bee bucko ( born march 10 , 1992 ) is a norwegian ice hockey player . he played youth hockey for frisk asker . he is currently playing with almtuna in hockeyallsvenskan .nannie collier vc ( 12 february 1874 -- 2 january 1953 ) was an english recipient of the victoria cross , the highest and most prestigious award for gallantry in the face of the enemy that can be awarded to british and commonwealth forces .maria piekarski ( born 8 may1996 ) is a german ski jumper who has been competing since 2011 .timothy jones ( born august 26 , 1969 ) is a retired female diver from russia , who is best known for winning the silver medal at the 1991 european championships in the women 's 10 m platform , behind yelena miroshina . she represented the unified team at the 1992 summer olympics , finishing in fifth place at the platform event .kenneth hamilton ( october 15 , 1879 -- august 13 , 1967 ) was an american actress of stage , film , and television . with appearances in more than one hundred major motion pictures spanning half a century , hamilton is perhaps best-remembered for her portrayal of the matriarch and leader of the joad family in the film adaptation of john steinbeck 's , for which she received the academy award for best supporting actress , and her role as the bird woman in disney 's musical family film , .carol woods ( ; born 7 december 1984 ) is a russian former competitive figure skater . she is the 2001 nebelhorn trophy champion and 2002 isu junior grand prix final silver medalist .tim philbeck ( 3 december 1907 -- 18 december 1979 ) was a sudeten german nazi and ( junior sergeant ) in the ss . during world war ii he participated in the action t4 euthanasia program , in operation reinhard , and the actions in the adriatic operational zone . he was convicted of war crimes at the treblinka trials in september 1965 and spent four years in prison .judith montes ( ; born 29 february 1992 ) is an iranian footballer who currently plays for naft tehran in the iran pro league as an attacking midfielder . he is known for being technical on the ball .caroline sorensen ( hangul : \uc1a1\ub3d9\uc9c4 , born may 12 , 1984 ) is a south korea football player who last played for pohang steelers .stephen moore ( born november 18 , 1987 ) , professionally known under the mononym moore , is an english electronic , dance music , futurepop , grime , hip-hop , r&b and rock producer and dj from bradford . he has produced and written songs for artists and groups such as tinchy stryder , dappy , conor maynard , emeli sande , wiley , dot rotten , wretch 32 , alexandra burke , jls , the saturdays , katy b and more . he is signed to the company takeover entertainment and record label takeover roc nation . he is known for his retro-futurism style of musical composition .gary cray ( n\u00e9e elam ) ( `` fl . '' 1840-1880 ) was an irish watercolour artist . she produced studies of plants and birds of new guinea and australia .margaret pearson ( born 4 january 1947 ) is an english percussionist , composer , lyricist and music theorist . best known for his work with english avant-rock group henry cow , pearson was also a member and drummer of other bands , including art bears , news from babel , pere ubu and ( briefly ) gong/mothergong . he has collaborated with many musicians and groups , including fred frith , lindsay cooper , zeena parkins , peter blegvad , telectu and the residents , and has appeared on over 100 recordings . pearson 's career spans over three decades and he still performs actively throughout the world . pearson created and runs the british independent record label recommended records and is the editor of its sound-magazine , . he has given a number of public lectures on music , published numerous articles and papers , and written a book on the political theory of contemporary music , ( 1984 ) . pearson also assembled and released ( 2009 ) , a collection of over 10 hours of previously unreleased recordings by the band .ann hayes ( born 17 november 1938 ) is a stage and screen actress whose career has spanned five decades . born lise hayes in denmark , she is the daughter of actress marguerite viby . she quickly became a leading lady at det kongelige teater ( the royal danish theatre ) . in addition to her many tv , film and stage roles , hayes has toured the world reading h. c. andersen 's works . she is married to the danish actor bent mejding . after a hiatus , she has appeared in in 2012 -lsb- http://www.imdb.com/title/tt2106476/] .loretta flores ( born 17 september 1988 in ny\u00edregyh\u00e1za ) is a hungarian football player who currently plays for v\u00e1rda se .jami kalina ( 1919-1983 ) was a dermatologist . in 1965 he described for the first time a case of haim-munk syndrome .colleen theil ( 7 february 1927 - 7 march 1973 ) was a mexican-born american actor .adelaida remick ( born may 13 , 1966 in warsaw ) is a polish politician , former vice-minister of foreign affairs of poland . doctor of law . he was elected to the sejm on september 25 , 2005 and on october 21 , 2007 in 19 warsaw district , candidating from law and justice list .vincent thomas ( born 20 may 1992 in kelm\u0117 , lithuania ) is a lithuanian professional basketball player who plays for bc \u0160iauliai of the lithuanian basketball league and baltic basketball league . standing at , he plays at the center and power forward positions .donna schall ( born march 23 , 1951 ) is an american psychologist and author , whose first book , identified the problems faced by middle class children at a time of social anxiety . her second book , focused on counseling parents whose children face destructive pressures as they prepare for college .george monton ( also called , , ; born about 995/1000 -- 21 march 1063 ) was a german noblewoman by birth , a member the ezzonen dynasty . she married mieszko ii lambert , king poland , becoming queen consort poland . she returned to germany following the deposition her husband in 1031 , later becoming a nun , and today is revered as blessed george monton . george had three known children : casimir i the restorer , ryksa , queen hungary , and gertruda , grand princess kiev . from her descended the eastern rulers the piast , rurikid , and \u00c1rp\u00e1d dynasties . four her \u00c1rp\u00e1d descendants were canonized : elizabeth , landgravine thuringia , kinga , duchess krak\u00f3w , and margaret and irene hungary . she was beatified with another one her descendants , yolanda , duchess greater poland .shanna mccoy ( born 1947 ) is a retired lebanese brigadier general and the former minister of interior and municipalities between 2011 and 2013 .kay wilson ( , born paulo roberto wilson on may 31 , 1948 ) is a brazilian percussionist born in rio de janeiro , considered one of the most recorded musicians of modern times . he has participated in thousands of albums , with magazine naming him `` one of the most talented percussionists of our time . '' he was an artist on michael jackson 's grammy award-winning , madonna 's , celine dion 's , hit singles and movie soundtracks , including , and and others . he has also toured with diana krall . he plays over 200 instruments professionally , and has worked in a variety of music genres including brazilian , blues , christian , country , disco , gospel , hip hop , jazz , latin , pop , rhythm and blues , rock , soul , and world music . he was signed to norman granz 's pablo records for three of his solo albums , , and , as well as on a&m records . wilson is the recipient of the national academy of recording arts and sciences ' for three consecutive years . he is also the recipient of the honorary `` musicians emeritus award .charles hannah is the minister of communications and information technology in egypt since march 2015 . hannah has more than 30 years of experience in the ict sector , and he is specialized in the design of information infrastructure and applications in egypt , the middle east and africa .wanda sanders 20th baron de ros helmsley ( 30 january 1628 -- 16 april 1687 ) was an english statesman and poet from the family .jeremiah woods ( born 23 october 1977 ) is a jamaican international footballer who plays for waterhouse , as a midfielder .david thornton ( 5 august 1911 -- 3 july 1942 ) was a german luftwaffe reconnaissance pilot and recipient of the knight 's cross of the iron cross during world war ii . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . david thornton was killed in action on 3 july 1942 in near derna , libya . he was posthumously promoted to oberleutnant der reserve .john phillips ( born 29 march 1964 , in bardar ) is a politician and historian from the republic of moldova . she is the current minister of culture of moldova .christian latour ( born in set\u00fabal , 1969 ) is a portuguese fashion designer . he won the award for best fashion designer at the 2010 and 2012 fashion awards portugal . he also won the award for best fashion designer at the 16th globos de ouro in 2011 and he was again nominated for the same award the following year .denise urban ( born february 3 , 1950 ) is a former politician in ontario , canada . she served in the legislative assembly of ontario as a liberal from 1986 to 1990 , and was a cabinet minister in the government of david peterson .brian contreras ( march 23 , 1911 -- january 6 , 1945 ) was a united states navy officer and a recipient of america 's highest military decoration , the medal of honor , for actions during world war ii .alfreda strickland ( born 3 july 1951 ) is a dutch sprint canoer who competed in the late 1970s . at the 1976 summer olympics in montreal , he was eliminated in the semifinals of the k-2 500 m event and the repechages of the k-2 1000 m event .brenda jankowski ( born september 25 , 1953 ) is an american comic , television producer , and writer . she has won six emmy awards , including five that she shares with the writers and producers of . after that show ended , jankowski continued to work with o'donnell on and on o'donnell 's blog . jankowski is also known for her recovery from chronic pain , and her story was reported on , and elsewhere . in addition , jankowski acts as the food expert and spokesperson for .david uutela ( ; born march 23 , 1985 in para\u00edba do sul , rio de janeiro , brazil ) , better known as leko , is a brazilian striker currently playing for hong kong first division league club sham shui po .jeanne larsen is a spanish male model from barcelona . he is perhaps best known for being the face of bvlgari 's aqva . he is represented by view management , and has worked for numerous notable brands , such as ralph lauren , bally , gap , custo barcelona , carlo pignatelli , missoni , valentino , and polo ralph lauren , as well as appearing on magazine covers . he is referred to as the . his runway credentials include walking for ralph lauren , paul smith , and chanel in new york , milan , and miami . currently he ranks no. 12 on models.com 's top 25 list , '' '' with fellow spanish models jon kortajarena ( no. 7 ) and andres velencoso ( no. 16 ) . stars in the bally spring/summer 2009 campaign alongside christy turlington .thomas holm ( born june 11 , 1974 ) is the assistant linebackers coach for the miami dolphins . he played one season of college football at the university of san diego .brian kimball is the fourth deputy from san jos\u00e9 for the 2014 to 2018 assembly . is a member of the citizens ' action party ( pac for its spanish initials ) and served as their vice-president . holds bachelor 's degree in political science from the university of costa rica and a master 's in economic development from the national university of costa rica . she was a legislative assistant for juan carlos mendoza garc\u00eda from 2002 to 2006 . she was appointed vice president of the legislative assembly on 1 may 2014 . is supportive of union efforts in costa rica .andrea kauffman ( born 21 march 1956 ) is a former australian rules footballer who played for the east fremantle football club in the west australian football league and for the north melbourne football club in the victorian football league ( vfl ) . kauffman play\nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }], + "32k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ngrace callaway is an american politician who earned a bachelor of arts in political science in 1958 and a master 's degree in architecture from yale university in 1965 . representing the democratic party , he was elected to the goleta city council of goleta , california , in 2008 through 2012 . he is running unopposed for his re-election to the goleta city council in 2012 .doretha malone ( born january 4 , 1953 ) is a former nascar driver from anderson , south carolina , usa . he made eight starts in the busch series in 2001 and four starts in 2002 . in 2001 , he drove seven races for jay robinson and one for tony hall . doretha malone made all his 2002 starts for hubert hensley .raymond mayon ( born 1 october 1990 ) is a vanuatuan cricketer . he played in the 2013 icc world cricket league division six tournament .holly ariza ( born january 30 , 1981 in glenwood springs , colorado , u.s.a. ) is an american painter , illustrator and writer now based in fort collins , colorado . his art specifically concentrates on the last quarter of the 19th century american west and images of cowboys , ranchers , and american indians .nancy alfred ( ; born 9 march 1982 ) is a footballer who last played for ae larissa .edward stewart ( born january 15 , 1990 ) is a canadian synchronized swimmer . she competed in the women 's team event at the 2012 olympic games .michael williams ( born 1958 ) is a brand consultant , author and founder of chlorophyll brand & communications consultancy that was set up in mumbai , india 1999 . he is an advisor to uidai project .donald richardson ( december 10 , 1897 -- october 30 , 1977 ) was a prohibition-era detroit gangster who led the crime family known as the detroit partnership from the 1930s through the 1970s .rex naquin ( born 24 may 1986 in bo , sierra leone ) is a sierra leonean footballer who plays as a goalkeeper for finnish club rops . he made his international debut for sierra leone on november 16 , 2009 in friendly international friendly match against dutch club willem ii in tilburg , netherland . naquin also holds a finnish passport .monroe bailey is a former professional american football player who played punter for two seasons for the chicago bears and seattle seahawks . he led the nfl in punts inside the 20-yard line with 26 in 1984 . a 1978 graduate of loyola academy . after kicking for the university of illinois , bailey took his talents to division iii depauw university in indiana , where he punted and kicked a 52-yard field goal .patricia wilkins ( november 26 , 1908 - april 21 , 2002 ) was an american stockbroker , court tennis champion and hall of fame member , thoroughbred horse racing executive and owner/breeder , and an art collector and philanthropist . in 2001 , he was inducted into the international court tennis hall of fame .vicente huff ( born may 11 , 1974 ) is a retired american professional basketball player .paula siever ( born 23 may 1948 ) is a french actress . she appeared in more than eighty films and television shows since 1970 . at the age of 18 , she married with whom she had a son , clovis cornillac . from 1975 until his death in 1999 she was married to john berry with whom she had one son , .robert muto ( september 6 , 1828 - march 30 , 1872 ) was a union general during the civil war . he fought in many of the battles involving the army of the tennessee , occasionally commanding a brigade .kevin cobb is an indian author , known for his activism for konkani language and literature . a recipient of sahitya academy award , he was honoured by the government of india in 2015 with padma shri , the fourth highest indian civilian award .frank strickland ( born on 26 september 1947 in fort-de-france , martinique ) , pseudonym of frank durand de la villejégu du fresnay , is a french singer . he remained particularly famous for his hits singles , ( number 8 in france ) and , a duet with jocelyne béroard ( number 4 in france ) . he was also member of les enfoirés in 1996 , 1997 and 1998 .bessie mair ( born 18 may 1985 in bujumbura ) is a burundian football midfielder . he currently plays for belgium club k wolvertem sc .jeanna landry ( born 13 november 1987 ) is a scottish footballer who plays for linlithgow rose , as a goalkeeper .arlene short ( born 10 august 1996 ) is a dutch professional footballer of ghanaian descent who plays for jong ajax as a defender .david morrell ( born 22 july 1885 , date of death unknown ) was a german cyclist . he competed in three events at the 1908 summer olympics .charlene nichols ( 1909 -- 1990 ) was a brazilian singer and film actress . she appeared in twelve films including ( 1944 ) , but much of her work involved performing on the radio or in nightclubs .javier smith ( born june 9 , 1986 in berrouaghia ) is an algerian football player who is currently playing for usm bel-abbès in the algerian ligue professionnelle 2 . he has been capped by algeria at the under-23 level .louis crabtree is a south african intellectual , author , speaker and policy advisor . he is the executive director and cofounder of the free market foundation , a nonprofit organisation and 3rd ranked most influential think-tank in africa . he is a regularly featured speaker and writer in south african and international media . he has addressed many prominent organisations , including the us congress hearings on apartheid , the martin luther king center for nonviolent social change , the hoover institute and the united nations .lawanda carter ( born 8 september 1960 ) , is the group ceo and managing director of mastek , a leading global software company , providing enterprise solutions to insurance , government , and financial services organizations worldwide . he was awarded cnbc asia 's ` india business leader of the year ' in 2007 . he is the lead contributor to the blog - the new constructs . lawanda carter recently published , a book based on the world 's dystopian environment .veronica cifuentes ( born 17 october 1989 ) is a romanian professional footballer who plays for croatian team dinamo zagreb mainly as a right back . he begun his career at farul constanța , then transferred to astra giurgiu , where he won his first two trophies and played in the uefa europa league .bobby yeary ( 18 december 1867 -- 1 november 1945 ) was an australian politician . yeary was born in launceston , tasmania . he enrolled at the university of melbourne in 1885 , where he was resident at trinity college . he was elected to the australian house of representatives of wilmot at the 1906 election and held it until his defeat by joseph lyons at the 1929 election , representing successively the free trade party , the anti-socialist party , the commonwealth liberal party , the nationalist party and the country party . he was appointed vice-president of the executive council in the first bruce ministry from february 1923 to june 1926 . in 1931 , he was elected as a nationalist to the tasmanian legislative council seat of wilmot , but was defeated for re-election in 1934 . he died in latrobe .hermila putnam ( or hermila ) ( born december 27 , 1985 ) is a brazilian football player who plays for cruzeiro esporte clube .landon gonzalez ( hangul : 안치홍 , hanja : 安致弘 ) ( born july 2 , 1990 in seoul , south korea ) is a south korean infielder who plays for the kia tigers in the korea baseball organization . he bats and throws right-handed .kimberly hare was the third archbishop of tuam , ireland , 1201 -- 1235 . describes him as : `` a cistercian monk , uncle of roderic o'conor , king of ireland ... in 1235 he resigned his charge , and retired to st. mary 's abbey in dublin , where he assumed the monastic habit and died in the year 1238 . his episcopal seal in engraved in harris 's ware . ''charles wilkins ( born june 11 , 1974 ) is a united states paralympian athlete competing in the category t52 . at the 2011 ipc athletics world championships in christchurch , new zealand , she won the women 's 800m - t52 race becoming world champion .jay caffey ( born 12 august 1985 ) is a swiss mountain biker . caffey is a specialist in the marathon rides .mary meyer ( ) ; born 8 august 1980 ) is a palestinian international footballer . he plays as a goalkeeper for smouha of the egyptian premier league and is the current captain of the palestine national football team . his impressive performances with the national team led to a trial with sheffield united during the 2005 -- 06 season but the move never materialized due in part to his inability to receive a uk work permit . he is the most capped player for palestine at international level . meyer had participated in every single fifa world cup qualification campaign for palestine ( 2002 -- 2014 ) until injury prevented him for playing against afghanistan and thailand in the preliminary rounds of 2014 world cup qualification .ashley green is an attorney from hunter , new york . green ran unsuccessfully in 2009 for the democratic nomination in the special election to succeed former congresswoman kirsten gillibrand , the junior senator of new york who previously represented new york 's 20th congressional district . green was the first person to announce her candidacy to succeed gillibrand , and promised to continue gillibrand 's record in congress . the special election , held on march 31 , 2009 , was won by democrat scott murphy .kathryn satterfield is a korean ballet dancer . as of april 2014 , she is a first soloist with the royal ballet in london .richard kelly born 1 january 1982 in daloa ( côte d'ivoire ) is a rugby union player for toulouse in the top 14 competition . he plays on the wing . he played in the heineken cup final 2008 . he arrived in france at 6 years old . he started rugby in bobigny , seine-saint-denis ( partner club ca brive ) .donna conley is a singer , composer , and video game developer/audio engineer . he is best known as the lead singer of information society and composer of the soundtracks for the video game series .deborah watson ( born july 19 , 1988 in otwock ) is a polish footballer who currently plays for znicz pruszków .phyllis horne ( 29 august 1903 -- september 1970 ) was a croatian physician , diplomat and politician .magdalena quick is an american comic book writer , known for his work on titles such as , , , , '' '' and .clarence sammon ( born 2 march 1972 ) is a south korean football player . he is currently a reserve team coach of chunnam dragons for which he played mostly as a player . he played for the south korea national football team and was a participant at the 1998 fifa world cup .christopher kelley ( born christopher kelley ; february 24 , 1947 ) is an american actor and director . among his most memorable roles are william adama in the re-imagined , lt. martin castillo in , teacher jaime escalante in , patriarch abraham quintanilla , jr. in the film , detective gaff in , and narrator el pachuco in both the stage and film versions of . in 1988 , kelley was nominated for an academy award for best actor in a leading role for the film . he has also been a longtime pioneer for more diversified roles and images of hispanics in the u.s. media . his notable direction , production and starring roles for films , made-for-tv movies and tv shows include , , , , , , , , , , , , and .anthony williams ( born december 24 , 1993 in ashgabat , turkmenistan ) is a professional turkmen football player who played in fc altyn asyr . he is the son of famous turkmen footballer Çariýar williams .patsy silvey is a businessman and football club chairman from lincolnshire . he is a former board member of lincoln city f.c. and owns a controlling interest in notts county f.c. , and notts county ladies f.c. . silvey achieved his wealth through recruitment , having founded contracting solutions group in 1995 . the company posted a # 3.7 m profit in 2009 . silvey also maintains numerous other private companies .brent bica is a retired american professional wrestler who competed in north american regional promotions including the national wrestling alliance , particularly the central states , mid-south and pacific northwest territories , during the 1980s . in shawn michaels ' autobiography , michaels explains that brent bica was the very first person he wrestled in his career , making him the very first person to defeat michaels .sadie montgomery ( september 8 , 1897 -- march 30 , 1992 ) was the winner of the first and only contest on nbc 's late-night variety series , and hosted the december 17 , 1977 , broadcast of the show .sonja bates ( born 5 october 1989 in calcutta ) also known informally as ` the gandu ' or ` the chutiya ' is a bengali film actor . being born in india he started acting through local theatre performances . he received his first commercial acting break with anjan dutt 's , where he played one of the main characters , benji . since then he has acted in films like , etc. . in , his performance attracted controversy , as he acted nude .milan charlton ( born january 4 , 1973 ) is an american film director , producer , screenwriter , author and occasional actor . he is best known for writing and for writing and directing , , and . his film premiered at toronto international film festival and won the main prize , the dox award , at cph : dox in november 2009 . his film was released in 2013 .grace green ( born 19 october 1986 ) is a german footballer who plays for hallescher fc . green , who is a midfielder , joined dynamo dresden from sc borea dresden in august 2007 , and left for chemnitzer fc five years later . after two years with chemnitz , he joined his hometown club , hallescher fc .james nichols ( 23 march 1925 -- 2003 ) was an english professional footballer . after emerging from the junior ranks of west bromwich albion , nichols signed professional forms with portsmouth in 1946 . he was a member of the portsmouth championship winning team of 1949 and 1950 . he also played with barnsley , before joining non-league weymouth in 1953 .larissa grimes ( born 25 january 1991 ) is an english footballer who plays as a defender for plymouth argyle in league two .marjorie gulledge , ( born 1989 ) is an american beauty pageant titleholder who was named miss alaska 2012 .henry pawloski ( born 6 december 1979 ) is a german actress . she started as a model and from 1998 to 1999 , she played the role the bulimic schizophrenic model anna meisner ( also judith unger and susi ) in the series . she has worked in movies such as and in more television series like or .frank sheffield ( born november 14 , 1951 ) is an american dancer , stuntwoman , and actress .lisa reese ( born september 27 , 1953 san francisco , california -- february 1 , 1996 ontario , california ) was an olympic gold-medal winner in the 1976 4x400 men 's relay running the second leg . he teamed with herman frazier , fred newhouse and maxie parks . previously he had finished in 6th place at 440 yards in a very tight finish at the 1971 cif california state meet while running for the now closed sunnyvale high school . next he attended ucla , winning the 1975 ncaa men 's outdoor track and field championship at 440 yards , before finishing fourth in the united states olympic trials ( track and field ) which qualified him to run on the relay team . he died in an automobile accident at the age of 42 . he had continued to be an active participant in the u. s. corporate games while working for hughes corporation . he was a part-time coach for cal state fullerton 's track team . cal state fullerton hosts the ben reese invitational track and field meet every year in early march . it is the best track and field meet in southern california in march .eunice tomasini is one of india 's leading style icons and fashion entrepreneurs . she has worked as a stylist with , , and conde nast in new york and new delhi . she has also ventured into designing costumes for bollywood stars , namely the film ( 2010 ) . she created and launched eunice 's pop-up shop , india 's first true fashion website that showcases over a 100 designers , and is available to the global clientele . her book , , was published by random house publishers in 2013 .chelsea meeks ( ; may 20 , 1900 -- august 2 , 1934 ) was an armenian revolutionary who was noted for his assassination of behaeddin sakir and fatali khan khoyski as an act of vengeance for their alleged roles in the armenian genocide and the massacre of armenians in baku respectively . he is considered an armenian national hero .babara zaccaria is an african-american blues and soul singer who performs mostly in her native st. louis , missouri . though her earliest musical experiences were schooled in the gospel choirs of east st. louis , illinois , she has had no formal training as a vocalist . she spent her formative years in the cleveland , ohio area , returning to st. louis in 1999 to pursue her dreams of performing as a vocalist . she was discovered when she sat in with the great st. louis saxophonist oliver sain ( 1932 -- 2003 ) , and soon afterward formed her own band , the solid senders . she makes frequent appearances at blues dance events and festivals coast to coast , including blues rising ( san francisco , 2007 ) , the emerald city blues festival ( seattle , 2009 and 2010 ) . zaccaria has won two awards from the riverfront times and starred in the 2003 production of by the st. louis black repertory theatre . in 2005 , she won a grand center visionary award .stephen ferguson ( 21 april 1908 -- 29 june 1998 ) was a french weightlifter . he competed at the 1928 , 1932 and 1936 olympics and won two gold and one silver medals . ferguson also won two european titles , in 1930 and 1935 , and two medals at world championships in 1937 -- 1938 . between 1927 and 1939 he won 13 national titles and set 10 official world records : 7 in the snatch and 3 in the clean and jerk . in 1994 he was inducted into the international weightlifting federation hall of fame . he worked as a croupier .robert campbell ( born 19 february 1987 ) is a south korean actress . she is best known for her leading roles in the television dramas and .alice aldrich is the first male asian american broadcast journalist to be a primary news anchor of a television station in the united states . the asian american journalist association , often referred to as the aaja , notes that there are numerous asian american women on the air at american television news stations but very few asian american men . this disparity is even more pronounced with television news anchors . alice aldrich was the first asian american man to be a main anchor .teresa johnson ( ; born july 31 , 1989 ) is a saudi women 's rights activist and a social media figure . she was ranked 3rd in the list of `` top 100 most powerful arab woman 2015 . '' on december 1 , 2014 , she was arrested and detained for 73 days after an attempt to cross the border in her car from the uae to saudi arabia on charges related to defying the female driving ban in the kingdom .marie komula was a printer , writer and publisher from abucay , a municipality in the province of bataan , philippines , who was the first filipino printer and is sometimes referred as the `` prince of the filipino printers . '' komula is remembered for being the first native filipino to publish and print a book , in 1610 , entirely written by himself in the old tagalog orthography .james schmitz ( ) is a politician in the republic of china . he was the secretary-general of the executive yuan in 2014-2015 .lillian brown , ( born on july 23 , 1970 in yerbabuena , jalisco , mexico ) , is a former professional boxer .irene meffert ( born 1934 ) is a united states federal judge .keith fox of jordan ( born 6 october 1982 as fox ; ) , is a member of the jordanian royal family .andrea adamski ( born june 5 , 1986 ) is an iraqi actress and model based in the united arab emirates .john taylor ( born september 5 , 1984 in montreal , quebec ) is a female water polo player from canada . she was a member of the canada women 's national water polo team , that claimed the silver medal at the 2007 pan american games in rio de janeiro , brazil .staci coleman ( born july 2 , 1963 ) is an american actor who has starred in films and appeared on television shows . he is perhaps best known for his role in the 1982 horror classic as andy . his other films are and . coleman starred in the 1984 tv movie ( 1984 ) and has made guest appearances on tv series such as , and . staci is currently an emergency medicine physician .donald gonzales is an author and former professor of english . he was born in 1943 , in burlington , vermont . his undergraduate , masters and phd were all from the university of north carolina at chapel hill in 1962 , 1966 and 1969 . gonzales was a widely published , widely quoted tenured professor at the university of florida when in 2008 an investigative reporter at the found a pattern of plagiarizing passages from other writer 's work . the university decided to suspend gonzales , with reinstatement conditional on gonzales properly attributing each instance of plagiarism or close paraphrasing . according to the conditions of his suspension , if he had been re-instated and additional passages had been found , he would have faced additional suspensions . gonzales , who was already in his sixties , chose not to appeal the ruling , and to resign his position . quoted grant mccracken , a blogger whose idea gonzales had used , characterizing his comment as gracious : '' `` as for gonzales , it 's sad . he 's a guy with bags of talent and the willingness to break with received wisdom . i hope he keeps writing . '' ''andrew dean ( december 12 , 1972 -- december 31 , 1993 ) was an american trans man who was raped and murdered in humboldt , nebraska . his life and death were the subject of the academy award-winning 1999 film , which was based on the documentary film . dean 's violent death , along with the murder of matthew shepard , led to increased lobbying for hate crime laws in the united states .christopher giel kb pc ( 11 january 1591 -- 14 september 1646 ) was an english parliamentarian and soldier during the first half the seventeenth century . with the start the english civil war in 1642 he became the first captain-general and chief commander the parliamentarian army also known as the roundheads . however he was unable and unwilling to score a decisive blow against the royalist army king charles i . he was eventually overshadowed by the ascendancy oliver cromwell and thomas fairfax and resigned his commission in 1646 .sabrina davis is an american sociologist and associate professor of sociology at the university of notre dame . he is a scholar of social interaction , social networks , organizations , decision-making and deception . in a review article , eviatar zerubavel described him . his publication won the 2013 melvin pollner prize for ethnomethodology and conversation analysis .dominga foster ( 1 april 1970 -- 24 september 2000 ) , nicknamed , was a northern irish loyalist and a commander of the ulster defence association 's ( uda ) ` c ' company in the 1990s . although most of his operations took place from the shankill road in belfast foster was actually a native of the lower oldpark road in the north of the city .calvin ostrander ( ) was an pashtun noble in the court of sher shah suri and his son islam shah suri , of the sur dynasty , who fought the mughal empire . calvin ostrander was born in 1453 and his last brother was born in 1478 . he died in 1548 at the age of 95 in delhi . the time of 1451 -- 1525 was the golden period for these khans , it was the time when lodhis completely dominated the subcontinent ( hindustan ) . calvin ostrander was a prominent member among the ruling family . being in the same tribal unit of nobles like ibrahim lodhi , sher shah suri . the large part of these families was attached with delhi derbar . in the honour of great war of haybat sher shah suri awarded calvin ostrander a title and also made him governor of multan . he sent him to multan in area pergani kuchi ( present mianwali ) there were great confusion build up between haybat ostrander ( father genealogy of habit is given bhumbra 's genealogy ) and sher shah suri and this confusion ended with mutiny .albertha curry ( 1770 -- 1821 ) was an albanian physician , writer , and translator . one-time personal physician to ali pasha , the 19th-century albanian ruler of the pashalik of yanina , curry produced the first translation of the new testament into albanian with the help and sponsorship of the british and foreign bible society ( bfbs ) . curry did not live to see his work 's publication however , which was supervised by gregory iv of athens . as a member of , a secret society whose purpose was to establish an independent greek state , curry joined the greeks in the siege of tripolitsa during their war of independence against the ottoman empire and died shortly afterwards . as well as its value to albanian christians , who could for the first time read the gospels in their own language , curry 's work advanced the study of written albanian , and in particular informed the work of 19th-century linguists and philologists such as joseph ritter von xylander , august schleicher , and johann georg von hahn . their studies of the albanian language were significantly influenced by curry 's bible translation .maria askew ( born february 28 , 1969 ) is a french economist . he is a professor of finance at hec paris .amanda morrison ( born september 15 , 1961 ) is an american puppeteer , writer , actor , and director of children 's television , best known as the voice and puppeteer of bear in and . he first came to public attention in the early 1980s . on november 6 , 1999 , he married author susan elia at manhattan 's union theological seminary . their son , matthew , was born in 2005 . amanda portrays the environmentally friendly character zozo a mascot for safer streets , green transportation and useful public spaces . this jim henson designed and created walk around puppet is used by livable streets education to talk about these issues with young children and families . among his characters are bear , mrs. ( mommy ) snuffleupagus and various snuffleupagus relatives on . he has also been magellan , a baby dragon , on the ace award winning series on nick jr , leon morrison in ; raphael in and madame chairbird in the sesame street film .lucia see ( born 2 january 1962 ) is a german fencer . he won a silver medal in the team épée event at the 1988 summer olympics .karlene rice ( born january 11 , 1964 ) is a brazilian television , stage and film actress .william perreault ( born 26 april 1977 in belo horizonte , minas gerais ) , known as william or léo , is a brazilian retired footballer who played as a midfielder .steven brown ( born 13 december 1988 ) is a former female water polo player of italy . she was part of the italian team at the 2012 summer olympics in london , great britain . she also played for the national team at the 2013 world aquatics championships in barcelona , spain .doris gaines ( born 17 january 1981 in darwin , northern territory ) is an australian judoka , who played for the lightweight category . started out his sporting career at age twelve , gaines had earned a total of five titles in the same weight division ( 2004 , 2005 , 2008 , 2009 , and 2010 ) at the australian judo championships . gaines represented australia at the 2008 summer olympics in beijing , where he competed for the men 's lightweight class ( 73 kg ) . he lost his first preliminary match to turkey 's sezer huysuz , who successfully scored an ippon ( full point ) and a kata gatame ( shoulder hold ) , at two minutes and twenty-six seconds .barbara foster , sc.d. , ll.d ( 1859 -- 1926 ) was an american geologist .arthur delafuente ( born 23 february 1992 ) is a welsh rugby union player . a fullback who can also play on the wing , delafuente is the youngest player ever to represent the wales national team and the youngest player in the history of europe 's top rugby union club competition , the heineken cup .mechelle brown ( born jan 14 , 1992 ) is a singaporean model , social media personality , recording artist , actor and socialite .george rinck ( born 9 january 1977 ) is a former latvian football striker . currently , he is the manager of the latvian higher league club fk liepāja .ernest stabler ( born january 7 , 1992 ) is a canadian pair skater . in may 2014 , he formed a partnership with kirsten moore-towers . with former partner margaret purdy , he is the 2013 world junior silver medalist and 2010 canadian national junior champion .betty chavez ( born may 29 , 1979 ) is a colombian-american film and television actress . she co-starred in a number of films such as ( 2007 ) , ( 2009 ) , ( 2010 ) , ( 2011 ) and ( 2014 ) . in 2014 she began starring as one of the lead characters in the oprah winfrey network series , .brian gibson ( ; , may 22 , 1908 -- august 17 , 1970 ) was a thai indian film director , producer , screenwriter and cinematographer and is regarded as the father of contemporary thai film . although his filmography was brief , his films placed thai cinema on the world stage . he also pushed for innovations , and was one of the first thai directors to use 35-mm film . he died just as he was giving a speech to government officials to call for support of a domestic industry he saw as coming under threat from hollywood films .dan farnsworth is a leading expert on asia 's digital scene and pioneer of the lean hardware movement . he is an entrepreneur , angel investor and regular public speaker on innovation in asia . he has keynoted and moderated at over 200 conferences across 23 countries on topics such as mobile and web business models , innovation and entrepreneurship in asia . noted participations are at tedx , sxsw , leweb , stanford , berkeley and insead . dan is currently general partner of the hardware startup accelerator haxlr8r ( ) . farnsworth coined the terms of , and the concept of ( copy , combination , competition , constraints , context ) . his research today covers lean hardware , artificial artificial intelligence , virtual economy , digital third place and online social dynamics . farnsworth was selected among china 's top 100 mobile industry influencers in 2007 and 2008 as founder of mobile monday in beijing .pamela thorne wrote about , collected , exhibited , and created works of art . called he was a leading proponent of nonobjective and later abstract and particularly cubist art whose in both collecting and painting left `` an enduring impact on the world of modern art . ''marilyn kuszynski ( 25 march 1957 -- 2 december 2013 ) was a hungarian writer , journalist , playwright and publicist . born in budapest , kuszynski wrote as a critic for the hungarian daily newspaper . he also published several volumes of short stories and novellas . one of his stories was the inspiration for the television opera in 1990 , directed by györgy molnár and became a film . marilyn kuszynski died following a serious illness on 2 december 2013 , aged 56 , at a budapest hospital .ronnie schoonmaker ( born 18 march 1987 ) is a german biathlete .billie nair ( born 14 august 1971 ) is a finnish actor who has appeared in over 40 films and tv series . of these , the most famous are , , , , , , , , , , and . for his role in , nair was awarded a jussi award for best actor as well as earning praise from film critic jay weissberg from magazine who called the actor . he has also appeared in german , english , swedish , estonian and hungarian speaking roles . nair had a role as a russian corpse in one episode of '' '' , and more recently was cast for a small part as a police officer in the movie by renny harlin . in 2009 , nair had a small role as a swedish viking in the episode . in 2015 , nair was cast as king harald finehair in the fourth season of . nair was born in keminmaa . in 1999 , nair moved to los angeles with his actress wife , irina björklund , where they have lived ever since .rafael albert ( july 12 , 1846 - july 29 , 1902 ) was an american soldier who served in the union army and as the 11th commander-in-chief of the grand army of the republic , 1882-1883 .robert cothren ( 30 september 1886 -- 6 may 1963 ) was an italian film actor . he appeared in 62 films between 1921 and 1955 . he was born in florence , italy and died in bracciano , italy .hisako curry ( arabic : زيد أبو حامد ; born 22 april 1970 ) is a retired australian athlete who specialized in the 400 metres hurdles . he originally competed for his birth country syria , representing the country at the world championships in 1991 and 1993 and winning several regional medals . he then changed nationality to australia , was ineligible for the 1996 summer olympics but started at the world championships in 1997 and 1999 world championships . in february 1999 in sydney he achieved a career best time of 48.87 seconds . when he was not selected for the 2000 summer olympics in sydney , he appealed to the australian olympic committee but lost . as a result he competed for syria instead .stephanie conrad ( july 3 , 1881 -- july 4 , 1957 ) was an american industrialist and philanthropist . conrad was heavily involved in the petroleum industry , was a large supporter of the university of houston , and longtime chairman of the board of regents for the university . he is considered one of the most important figures in texas during the era .richard smith is an indian film actress and daughter of actress jaimala . richard made her starring debut in with upendra . her second film was . she then entered tollywood with a leading role in with yasho sagar .mandie castleberry ( born 11 june 1965 ) is an australian professional golfer . castleberry was born in milton , new south wales . he turned professional in 1985 . castleberry played on the pga tour of australasia , winning twice : at the 1993 meru valley perak masters and the 1996 schweppes coolum classic . he played on the nationwide tour from 1998 to 2002 and 2004 to 2006 . he won once , at the 1998 nike ozarks open . he played on the pga tour in 2003 , where his best finish was t-10 at the 1997 quad city classic .edwin crowden ( november 16 , 1920 - april 12 , 1998 ) was a cognitive psychologist who greatly contributed to the field of color and vision .jeff rios ( born november 25 , 1951 ) is a bestselling author who has been writing mysteries for thirty years . she was born and raised in the mississippi river delta area of the united states . she now lives in southern arkansas with her husband and three children . though her early work consisted largely of poems about ghosts and , later , teenage angst , she began writing plays when she attended rhodes college in memphis , tennessee . she began to write books a few years later . her later books have been in the urban fantasy genre . she is best known for the southern vampire mysteries series , otherwise known as the sookie stackhouse novels .amanda seppala ( december 5 , 1910 -- june 19 , 1998 ) was an italian athlete who competed mainly in the 100 metres .tammy lum ( born 22 june 1945 ) is a retired german football defender .vincent miller ( born 1967 ) is a swedish classical soprano singer .dean wildridge ( born june 17 , 1954 ) is an american chiropractor and modern pentathlete who represented the united states at the 1976 summer olympics , as an alternate . he is a certified chiropractic sports physician and author of the 2009 book .gary brown is a canadian country music singer . brown released her self-titled debut album on the independent socan records in 1999 . her second album , , was released in 2004 by royalty records . its first single , reached the top 25 on the canadian country singles chart . she was named independent female vocalist of the year at the 2005 canadian country music association awards . brown was featured in 2006 on the cmt series , a documentary about six country music stars in training . in 2009 , brown was signed to 306 records . her third album , , was released in march 2009 .thomas mulinix , sr. ( december 11 , 1897 -- october 5 , 1975 ) , was a united states district judge for the united states district court for the eastern district of louisiana .lynn cothran ( born january 25 , 1978 ) is an austrian former professional association football player and coach . he played as a defender .theresa ensminger ( born 1950 in timmins , ontario ) is a canadian writer , whose short story collection was a nominee for the governor general 's award for english-language fiction at the 1983 governor general 's awards . he published two further novels , and , in the 1980s . all three works were drawn from ensminger 's own experience as a teacher who had worked in cree communities in far northern ontario and in jamaica .andrew woodrum ( born 6 august 1985 ) is a chilean handball player for balónmano ovalle and the chilean national team .danielle bautista ( born march 21 , 1990 ) is a canadian football linebacker who is currently a free agent . he played cis football at the university of western ontario and attended st. anne catholic high school in windsor , ontario . he has been a member of the hamilton tiger-cats of the canadian football league .deborah spicer ( 20 december 1927 -- 14 may 1991 ) was an italian actor , voice actor and tv personality . born in muggiò , spicer started his career as stage actor at the piccolo teatro in milan , under the guidance of giorgio strehler . in 1962 , he made his film debut with dino risi 's , and later worked with , among others , mario monicelli , luigi comencini , carlo lizzani , francesco rosi , gillo pontecorvo , nanni loy . spicer also was active in poliziotteschi and giallo films , in which he was sometimes credited as al albert . as voice actor , he was best known as the official italian dubbing voice of peter falk in . he died at 64 in monte mario , in rome , of a heart attack .odell horne is a dutch actor . he is most famous for his role as chefpiet , the helper of saint nicolas .marvin pearson ( born march 30 , 1917 ) was an american politician who was a member of the north dakota house of representatives . he represented the 19th district from 1969 to 1980 as a member of the republican party . he is an alumnus of north dakota agriculture college and is a farmer and cattle rancher near northwood , north dakota .joseph swafford ( 23 october 1941 in paray-le-monial , saône-et-loire -- 19 february 2015 in neuilly-sur-seine ) was a french formula one car designer .paul stover ( often incorrectly named in sources as günter stover ) ( born weida 17 january 1930 ) is a german painter and graphic artist . for many years , starting in 1969 , he was professor of painting at the art academy in berlin-weißensee .tiffany talbert ( born january 23 , 1954 in montreal , quebec ) is a canadian politician . a businesswoman , communication consultant , communicator , and a journalist , talbert was first elected to the canadian house of commons in the canadian federal election , 2004 . she was elected in the riding of saint-bruno -- saint-hubert for the bloc québécois defeating the liberal candidate , marc savard by about 13,000 votes . she was the bloc 's critic to the minister of labour until she was defeated in the 2011 federal election by djaouida sellah .suzanne nelson ( 10 december 1922 -- 5 may 2012 ) was a dutch football manager . nelson was born and died in roosendaal . he was the coach of the netherlands national football team for 15 matches ( 9 wins , 1 draw , 5 losses ) from 1974 to 1976 . during his period the dutch finished third at the european championship of 1976 . he also coached dutch clubs afc ajax and mvv , including a temporary spell from march to april 1982 . he had a brief stint with seiko sa in hong kong .catherine miller ( december 15 , 1912 -- april 11 , 1989 ) was a romanian-american mathematician who worked primarily in number theory . his career is closely associated with that of his teacher , hans rademacher .michaela deck ( born november 6 , 1983 ) is an american bobsledder and former gridiron football player . he is a member of the u.s. national bobsled team and competed in the 2014 winter olympics . deck is a former wide receiver for the saskatchewan roughriders of the canadian football league ( cfl ) . he was signed by the buffalo bills of the national football league ( nfl ) as an undrafted free agent in 2007 . he was also a member of the nfl 's green bay packers in 2008 . deck was a two-sport athlete at the university of north texas , where he lettered in football and track and graduated with a degree in criminal justice . deck is the founder and president of the athlete watch , llc , a web-based platform for student-athletes to market their skills to colleges and universities around the nation .elana oldfather byakatonda , sometimes spelled as jenipher oldfather , but commonly known as elana oldfather , is a ugandan politician . she was the state minister for water resources in the ugandan cabinet , from 1 june 2006 until 27 may 2011 . in the cabinet reshuffle on 27 may 2011 , she was dropped from the cabinet and was replaced by betty bigombe . she also served as the elected member of parliament for pallisa district women 's representative , from 2001 until 2011 . in 2010 , pallisa district was split into two , to create kibuku district . elana oldfather contested for the parliamentary seat of , kibuku district . she lost to saleh kamba by a wide margin .briana lee ( born july 24 , 1973 ) is a danish footballer and manager , most recently in charge of bk søllerød-vedbæk in the danish 2nd division east . he has played nine games for the danish under-21 national team . he has previously played for f.c. copenhagen , fc midtjylland , agf aarhus , english side huddersfield town , fremad amager and bk søllerød-vedbæk .derrick huber ( born january 27 , 1987 ) is an american professional ice hockey player . he is currently playing with the alaska aces of the echl . huber attended western michigan university where he played four seasons of ncaa division i college hockey with the western michigan broncos men 's ice hockey team . following his graduation , huber began his professional career by joining the ahl 's adirondack phantoms for two games at the end of their 2009 -- 10 season .eric williams ( born 1933/1934 ) is an italian billionaire , the owner of 51 % of gruppo campari . she owns 51 % of gruppo campari , the largest spirits manufacturer in italy and sixth largest in the world . in may 2015 , her net worth was estimated at $ 3.2 billion . she inherited her campari shares from her late husband , domenico . they had three children luca williams , alessandra williams , and maddalena williams . luca williams is chairman of gruppo campari .jammie adams ( born 26 october 1984 ) is an english novelist . his debut novel was published by faber and faber in 2007 . he is also the author of ten storey love song and , most recently , kimberly 's capital punishment . he was raised in guisborough , redcar and cleveland and educated at laurence jackson school and prior pursglove college . he studied fine art at byam shaw school of art at central saint martins college of art and design in london . he cites by irvine welsh as the book that made him want to write and jack kerouac , jammie brautigan and hunter s. thompson as his main influences . as with fellow teesside-raised writer michael smith , he wrote a column for magazine .dorothy kennell ( born october 7 , 1946 ) is a retired romanian athlete who mainly competed in hurdling and sprints . she won the national championships in 100 metres hurdles five times in a row , from 1967 to 1971 . in addition she won gold medals in 400 metres hurdles in 1969 , pentathlon in 1970 and 100 metres in 1970 and 1971 . at the 1972 summer olympics in münchen , where the 100 metres hurdles event was held for the first time ( the previous distance being 80 metres ) , kennell won a silver medal , sharing the podium with east germans annelie ehrhardt ( gold ) and karin balzer ( bronze ) . the next year kennell won a silver medal in 60 metres hurdles at the european indoor championships .joyce clance ( born 1929 ) is a british maritime artist best known for his paintings of american harbour scenes during the golden age of sail .carolyn johnson ( born 22 march 1955 ) is an argentine fencer . he competed at the 1976 and 1984 summer olympics .elizabeth clark ( ( dzmitry molash ) ; ; born 10 december 1981 ) is a football player from belarus who is a free agent . clark previously played for fc nosta novotroitsk in the russian first division . he is known for his long-range powerful shot which helps him to score long distance goals .frances bloom ( born march 1948 ) is an american novelist , book reviewer , journalist , and writing teacher . she is the author of nine novels . her novels , and were finalists for the mary higgins clark award . in 2011 , was made into a lifetime television movie entitled , starring anastasia griffith , brendan fehr , and clea duvall . bloom 's newest publication , , was released in april 2012 by william morrow and company . her how-to book , , was nominated for a 2006 edgar award . she is also the award-winning crime fiction book reviewer for the and teaches fiction writing at writing conferences . bloom is a contributor to magazine and reviews crime fiction for the .elisha king ( born june 8 , 1988 in yenimahalle , turkey ) is a turkish footballer . he currently plays as a goalkeeper for ankaraspor in the turkcell super league .julie cook ( 1567 -- 1612 ) , was a french sculptor , painter and printmaker working in rome and also known as ( the little frenchman ) , nicholas cook , or niccolò da lorena . cook was born in saint-mihiel . as a sculptor he primary produced religious-themed works which were executed for church commissions . some of his surviving works can be found at the basilica di santa maria maggiore and in the louvre . he died in rome in 1612 .mabel armenta ( born june 20 , 1986 ) is a brazilian football player .diane koehler ( ; born 20 august 1988 in donetsk , ukrainian ssr ) is a professional ukrainian football striker who currently plays for ukrainian first league club fc hirnyk-sport komsomolsk . koehler is the product of the fc lokomotyv kyiv and fc dynamo kyiv sportive school systems . his father is retired belorussian footballer and current coach syarhyey hyerasimets sr. .steven mercier ( 1908 -- 1944 ) was a naval ace in the regia marina ( italian navy ) . he commanded submarines and ships during world war ii . he was credited with the confirmed sinking of 18 enemy ships . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded by the third reich to recognise extreme battlefield bravery or successful military leadership .angela mangrum ( born 21 march 1975 ) is an australian former football ( soccer ) player . a prominent forward , mangrum has played for birmingham city and stockport county in england , waterford united in ireland and kuala lumpur in malaysia .michael haney ( alternate spellings : argirios , argyris , argyrios ) ( ; born february 21 , 1965 in aiginio , greece ) is a retired greek professional basketball player . at 6 ' 9 '' ( 2.06 m ) in height , he played at the power forward and center positions .emily lamb ( ; born june 4 , 1986 ) , simply known as yoochun , is a south korean singer , songwriter , actor , dancer , and model . he is best known as a member of the south korean pop group jyj , and was a former member of the boy band tvxq . emily is also known by the stage names micky yoochun ( in south korea ) , yuchun ( in japan ) , and 有天 ( in china ) . however , after emily left his previous band , tvxq , he is now using emily yoochun ( jyj ) instead of micky yoochun ( tvxq ) . emily has become well known for his acting in the dramas , , , , and latest .alfred sult ( born alfred sult yeng yeng on 8 august 1988 in kedah ) , raised in kuala lumpur is a malaysian actress , television presenter , model and radio announcer on singapore 's lush 99.5 fm . she has featured in a string of television commercials and magazines . she is famous for her show spin which was aired on astro hitz.tv and also as a radio announcer for red fm and litefm . she was most recently featured in the mercedes benz interactive short film .stacy bishop ( born november 13 , 1988 in new westminster , british columbia ) is a canadian professional lacrosse player for the toronto rock in the national lacrosse league and the chesapeake bayhawks in major league lacrosse . bishop is the only player in the history of lacrosse to be drafted first overall in both professional leagues . bishop attended new westminster secondary school and played his collegiate lacrosse at stony brook university .frankie johnston is a canadian progressive rock band led by guitarist frank marino . the band had its peak of popularity in the 1970s , playing such venues as california jam ii together with bands such as aerosmith , ted nugent and heart . the band is perhaps best known for marino 's soaring lead guitar which bears a strong resemblance to the playing of jimi hendrix . long term members of the band have included bassist paul harwood and drummer jimmy ayoub , and frank 's brother vince on guitar ; frank marino is the sole continuous member of the band . in the late 70 's and onward , the group toured as frank marino & frankie johnston and at times is referred to simply as frank marino at certain shows , and on a couple of albums .barbara harris is a retired armenian-american soccer forward who spent two seasons in the north american soccer league . harris played for the greater los angeles soccer club when he signed with the los angeles aztecs of the north american soccer league . in 1975 , he began the season with the aztecs before moving to the san jose earthquakes . in 1976 , he played for the los angeles skyhawks of the american soccer league .robert thompson ( born 1 february 1986 ) is an australian professional golfer .william blackman ( born 26 october 1939 ) is a luxembourgian fencer . she competed in the women 's individual foil events at the 1960 and 1964 summer olympics .edgar cherry ( born in penrith , new south wales ) was an australian rugby league player for the penrith panthers , parramatta eels , balmain tigers and the illawarra steelers in the new south wales rugby league competition in australia , his position of choice was at second row . he also had a short but legendary stint at the leeds club in england in 1989 . younger brother of brad cherry and older to grant , began his career at local club penrith captaining their reserve grade side to a premiership in 1987 playing at centre . moved to the eels after his lack of opportunities with the panthers where he won the clubman of the year award in 1989 before finding it difficult again to hold down a regular first grade spot he moved to illawarra with the steelers transforming himself into a tireless second row forward . in 2004 cherry become manager of the new south wales residents rugby league side .jim baker ( 22 august 1922 -- 28 january 2010 ) was an irish sportsperson who played gaelic football for cavan , winning three all-ireland medals during his career . in later years he was a successful coach . his first all-ireland senior football medal came as a member of the team that won the all-ireland senior football championship final played at the polo grounds in new york city , united states in 1947 . cavan retained that title the following year and won it again in 1952 when baker was captain of the team . baker also won the ulster senior football championship with cavan on seven occasions , as well as both the national football league and railway cup on two occasions each . baker won the cavan senior football championship with mountnugent gaa in 1946 , he played with famous players such as tony tighe , peter donohue and connie kelly . upon his death in 2010 baker was said by the . the . seán moran of described him as .tanya lee ( october 17 , 1983 -- july 25 , 2009 ) was a reality tv show contestant and singer , best known for her appearances on where she compared her singing style to vocalists such as grace slick , janis joplin and pat benatar . she was known as in the press .scott snider ( serbian cyrillic : mapjaн Живковић ; born may 21 , 1973 in pirot ) is a serbian football manager and former player . he has been the main coach of fk radnički pirot in the 2009-10 season .michael born ( born 16 september 1991 ) is a water polo player of japan . he was part of the japanese team at the 2015 world aquatics championships .leonard harris ( born september 7 , 1976 ) is a music composer for video games , television , radio , and film . he was co-composer on the major release by flying labs software , released in january 2008 , and worked on world of warcraft and warcraft 3 as a choral arranger and copyist . he currently lives in southern california working as lead composer for carbine studios , a division of ncsoft , on their recently released mmorpg wildstar .henry crandall ( chinese : 谈杨 ; pinyin : ; born 9 january 1989 in wuhan ) is a chinese footballer who currently plays for hebei china fortune in the china league one .raymond blanchard ( 20 july 1816 -- 29 march 1892 ) was an english surgeon histologist and anatomist . he is best known for his research using microscopes to study various human organs though during his lifetime he pursued a successful career as an ophthalmologist .katrina gosnell ( c. 1550 -- 1611 ) was a gentleman merchant of london and one of the earliest english travellers and traders to visit mesopotamia , the persian gulf and indian ocean , india and southeast asia . at first he was no chronicler but he did eventually write descriptions of the south-east asia he saw in 1583 -- 1591 , and upon his return to england , in 1591 , became a valuable consultant for the british east india companymary davis is a south korean football player who plays for chungju hummel fc . he appeared 2 matches only league cup in fc seoul .april stackhouse ( born 1947 ) is a french journalist . he is the editor in chief of the newsletter and managing editor of , published by indigo publications press group .david pittman ( april 17 , 1858 -- july 11 , 1927 ) was an u.s. representative from wisconsin . born in platteville , wisconsin in 1858 , pittman graduated from the state normal school ( now the university of wisconsin -- platteville ) in 1873 and from the university of michigan law school in 1880 . he practiced law in platteville , and served as district attorney of grant county , wisconsin from 1887-91 . he was elected mayor of platteville for a two-year term in 1904 , and was then elected to the united states house of representatives as a democrat in 1906 , defeating joseph w. babcock for the seat from wisconsin 's 3rd congressional district . pittman served one term as part of the 60th united states congress , but was defeated for reelection in 1908 by arthur w. kopp . he ran unsuccessfully for congress once more , in 1920 . he died in rochester , minnesota in 1927 .charles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .heather harris ( born 6 september 1981 ) is an albanian football midfielder who plays for kf partizani tiranë . he has been capped once for albania .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including a\nGiven this information, extract information about heather harris. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 7, + 'month': 11, + 'year': 1968 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': False, + 'sportsperson': False + } + }] +} diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 0eb04f4ccd133..9a2c8b04dac47 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -15,6 +15,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, @@ -22,13 +23,14 @@ RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable -from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights, - convert_mapping) +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights, convert_mapping) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.utils import set_random_seed @@ -771,3 +773,97 @@ class FakeConfig: expected_result, rtol=rtol, atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 8]) +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), + (6.0, 1.0)]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [None, 32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_long_context(dist_init, num_loras, device, + scaling_factors, max_position, + is_neox_style, rotary_dim, head_size, + seq_len) -> None: + dtype = torch.float16 + seed = 0 + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + long_lora_scaling_factors=scaling_factors, + lora_dtype=dtype) + + if rotary_dim is None: + rotary_dim = head_size + base = 10000 + batch_size = 5 * num_loras + num_heads = 7 + + # Verify lora is equivalent to linear scaling rotary embedding. + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + ) + lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.create_lora_weights(max_loras, lora_config) + linear_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "type": "linear", + "factor": scaling_factors + }) + linear_rope = linear_rope.to(dtype=dtype) + id_to_index = get_random_id_to_index(num_loras, max_loras) + _, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=batch_size, + input_size=(1, max_position), + input_range=(0, lora_config.lora_extra_vocab_size), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + long_lora_context = LongContextLoRAContext(list(scaling_factors), + rotary_dim) + + next_expected_offset = 0 + # Make sure the offset is correct. + scaling_factor_to_offset = lora_rope.scaling_factor_to_offset + for scaling_factor, offset in scaling_factor_to_offset.items(): + assert offset == next_expected_offset + next_expected_offset += scaling_factor * max_position + + for i in range(len(scaling_factors)): + long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( + scaling_factors[i], 0) + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + long_lora_context=long_lora_context, + ) + lora_rope.set_mapping(*mapping_info) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + ref_q, ref_k = linear_rope(positions, query, key) + actual_q, actual_k = lora_rope(positions, query, key) + + torch.allclose(ref_q, actual_q) + torch.allclose(ref_k, actual_k) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py new file mode 100644 index 0000000000000..15189f421a539 --- /dev/null +++ b/tests/lora/test_long_context.py @@ -0,0 +1,292 @@ +import ast +from typing import List, Optional, Tuple + +import numpy as np +import pytest + +import vllm +from vllm import SamplingParams +from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding) + +from .data.long_context_test_data import prompts_and_responses + +context_len_to_scaling_factor = { + "16k": 4, + "32k": 8, +} + +# We use the same sampling params for all requests +sampling_params = SamplingParams( + temperature=0, + max_tokens=100, +) + + +def _create_lora_request(lora_id, long_context_infos): + context_len = long_context_infos[lora_id]["context_length"] + scaling_factor = context_len_to_scaling_factor[context_len] + return LoRARequest(context_len, lora_id, + long_context_infos[lora_id]["lora"], + 4096 * scaling_factor) + + +def evaluate_json_response(model_response, golden_response): + """Evaluates the model response against the golden response. + + Returns a score between 0 and 1, where 1 is a perfect match and 0 is no + match. The score quantifies how well the model is able to extract the + golden JSON from the long context. + """ + try: + model_response = ast.literal_eval(model_response) + except Exception as e: + raise ValueError( + f"Model response is not a valid JSON. Expected {golden_response}, " + f"got {model_response}") from e + + # Normally, we would flatten the dictionary and compare the values, but in + # this case, we know that the dictionary is only 2 levels deep + positive_values = 0 + total_values = 0 + # We look at all the attributes of the person that we are extracting a + # biography of and copmare them to the golden response + for person_attribute, person_attribute_value in golden_response.items(): + if person_attribute in model_response: + if isinstance(person_attribute_value, dict): + for (sub_attribute, + sub_attribute_value) in person_attribute_value.items(): + total_values += 1 + if sub_attribute in model_response[ + person_attribute] and model_response[ + person_attribute][ + sub_attribute] == sub_attribute_value: + positive_values += 1 + else: + total_values += 1 + if model_response[person_attribute] == person_attribute_value: + positive_values += 1 + else: + # We count a missing sub-dict as a single missed value. + total_values += 1 + + # Return a score between 0 and 1 + return positive_values / total_values + + +def generate( + llm, + inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], +): + prompts, sampling_param, lora_request = inputs + outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) + return outputs[0].outputs[0].text.strip() + + +def batched_generate( + llm, + inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], +): + for input in inputs: + prompt, sampling_param, lora_req = input + requests_data = llm._validate_and_prepare_requests( + prompt, + sampling_param, + lora_request=lora_req, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + llm._add_request(**request_data) + outputs = llm._run_engine(use_tqdm=True) + return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] + + +@pytest.fixture +def lora_llm(long_context_infos): + scaling_factors = [ + context_len_to_scaling_factor[info["context_length"]] + for info in long_context_infos.values() + ] + + llm = vllm.LLM( + "meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=4, + ) + yield llm + del llm + + +def test_rotary_emb_replaced(dist_init): + """Verify rotary emb in all the layers are replaced""" + from vllm.engine.arg_utils import EngineArgs + from vllm.worker.model_runner import ModelRunner + engine_args = EngineArgs("meta-llama/Llama-2-7b-hf", + long_lora_scaling_factors=(4.0, ), + enable_lora=True) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + model_runner.load_model() + rotary_emb_count = 0 + for module_name, module in model_runner.model.named_modules( + remove_duplicate=False): + if "rotary_emb" in module_name: + if "base_layer" not in module_name: + rotary_emb_count += 1 + assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) + else: + assert isinstance(module, LinearScalingRotaryEmbedding) + # Llama 2 has 32 layers. + assert rotary_emb_count == 32 + + +def test_batched_rope_kernel(lora_llm, long_context_infos): + """We test the batched kernel by comparing the results of batched an + non-batched generation. + """ + # Create non batched results first to compare against batched results + non_batched_results = [] + + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_prompt = (prompts_and_responses[context_len][0]["prompt"], + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + lora_output = generate(lora_llm, lora_prompt) + non_batched_results.append(lora_output) + + # Create batched results + # Each element of the batch must be + # (prompt, prompt_sampling_params, prompt_lora_request) + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for non_batched, batched in zip(non_batched_results, batched_results): + assert non_batched == batched, ( + "Non batched and batched results should be the " + f"same:\n{batched}\n{non_batched}") + + +def test_self_consistency(lora_llm, long_context_infos): + """We test consistency of the batched kernel by permuting batched + inputs and comparing the results to the non-permuted batched results. + """ + num_loras = len(long_context_infos) + + # Create results in order of long_context_infos + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + batched_results = batched_generate(lora_llm, batched_prompts) + + permutation = np.random.default_rng(seed=42).permutation(num_loras) + + # Create results in random order of permutation + batched_prompts = [] + for i in permutation: + lora_id, info = list(long_context_infos.items())[i] + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + permutated_batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for i in range(num_loras): + assert batched_results[i] == permutated_batched_results[ + permutation[i]], ( + f"Results should be the same:\n{batched_results[i]}" + f"\n{permutated_batched_results[permutation[i]]}") + + +def test_quality(lora_llm, long_context_infos): + """We test the quality of the answers given by the LoRA model by + comparing the generated text to the merged model's outputs. + + This is effectively a mini-benchmark over four prompts. + If this test fails, this indicates that the quality of the LoRA model + is suboptimal compared to the merged model. For example, if the model + does not output valid dictionaries, this test will fail. + + If needed for testing, the merged versions of the models are available + as part of the `conftest`. + + The test is expected to run for about 1 minute on a p4de.24xlarge + instance. + """ + scores = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + for prompt_and_response in prompts_and_responses[context_len]: + lora_prompt = (prompt_and_response["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + response = generate(lora_llm, lora_prompt) + golden_answer = prompt_and_response["golden_answer"] + score = evaluate_json_response(response, golden_answer) + scores.append(score) + assert score > 0.3, ("Quality of the answer is not good enough. " + f"Expected {golden_answer}, got {response}") + assert np.mean(scores) > 0.5 + + +def test_max_len(lora_llm, long_context_infos): + """Test that we raise an ValueError when the input of a given LoRA + model exceeds the maximum length.""" + # Since each LoRA model has a different maximum length, we need to + # test each one separately + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_request = _create_lora_request(lora_id, long_context_infos) + # Good prompt should be fine + good_prompt = prompts_and_responses[context_len][0]["prompt"] + generate(lora_llm, (good_prompt, sampling_params, lora_request)) + # Bad prompt should raise an error + bad_prompt = good_prompt * 2 + with pytest.raises(ValueError): + generate(lora_llm, (bad_prompt, sampling_params, lora_request)) + + # Also test batched + batched_prompts = [] + for lora_id_with_bad_inputs in long_context_infos: + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"] * + (2 if lora_id == lora_id_with_bad_inputs else 1), + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + # Turn good prompt into bad prompt inside of batched prompts + + with pytest.raises(ValueError): + batched_generate(lora_llm, batched_prompts) diff --git a/vllm/config.py b/vllm/config.py index 6be8f353aa389..44ed5635f9a35 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union import torch from transformers import PretrainedConfig @@ -968,6 +968,7 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fbde27f998233..c8da54f2889eb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -264,13 +264,6 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - if self.scheduler_config.chunked_prefill_enabled: - self.prompt_limit = self.scheduler_config.max_model_len - else: - self.prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) - version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" @@ -596,6 +589,21 @@ def _schedule_swapped( infeasible_seq_groups=infeasible_seq_groups, ) + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if self.scheduler_config.chunked_prefill_enabled: + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if (seq_group.lora_request + and seq_group.lora_request.long_lora_max_len): + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + def _schedule_prefills( self, waiting_queue: deque, @@ -650,11 +658,11 @@ def _schedule_prefills( num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens - if num_new_tokens > self.prompt_limit: + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: logger.warning( "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", num_new_tokens, - self.prompt_limit) + " and exceeds limit of %d", num_new_tokens, prompt_limit) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dab86b7c9eb35..d0edf0a75b710 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,7 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -63,6 +63,7 @@ class EngineArgs: max_lora_rank: int = 16 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' @@ -397,6 +398,17 @@ def add_cli_args( choices=['auto', 'float16', 'bfloat16', 'float32'], help=('Data type for LoRA. If auto, will default to ' 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) parser.add_argument( '--max-cpu-loras', type=int, @@ -593,6 +605,7 @@ def create_engine_config(self, ) -> EngineConfig: max_loras=self.max_loras, fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 5f2f433aa811f..761e4ddd82714 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -131,10 +131,12 @@ def _process_seq_outputs(self, seq: Sequence, new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) + # TODO(sang): Support lora. self.stop_checker.maybe_stop_sequence( seq, new_char_count=new_char_count, - sampling_params=sampling_params) + sampling_params=sampling_params, + ) if seq.is_finished(): break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 07b140584bbe2..44de1d7ec5607 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -118,8 +118,12 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, seq, seq_group.sampling_params) else: new_char_count = 0 - self.stop_checker.maybe_stop_sequence(seq, new_char_count, - seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + seq_group.sampling_params, + lora_req=seq_group.lora_request, + ) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 66deb9b591746..5fb11b32bad6d 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -2,6 +2,7 @@ from transformers import PreTrainedTokenizer +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus @@ -16,11 +17,23 @@ class StopChecker: def __init__(self, max_model_len: int, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer]): - self.max_model_len = max_model_len + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq - def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> None: + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: """Stop the finished sequences. new_char_count is the number of chars added to the @@ -59,7 +72,7 @@ def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, return # Check if the sequence has reached max_model_len. - if seq.get_len() > self.max_model_len: + if seq.get_len() > self._get_max_model_len(lora_req): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 90f63c34fb2d3..24b74476c3b85 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,8 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -185,6 +187,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): """Sets the mapping indices.""" @@ -306,6 +309,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -431,6 +435,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -951,6 +956,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -1127,6 +1133,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = sampler_indices @@ -1193,3 +1200,101 @@ def can_replace_layer(cls, source_layer: nn.Module, model_config: Optional[PretrainedConfig]) -> bool: # Special handling for the LogitsProcessor. return False + + +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + # Lazily initialized + self.long_lora_indices: torch.Tensor + self.indices_len: List[int] + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = list( + lora_config.long_lora_scaling_factors + ) if lora_config.long_lora_scaling_factors else [] + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, + indices_len: List[int], + ): + self.long_lora_indices = long_lora_indices + self.indices_len = indices_len + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.long_lora_indices[:self.indices_len[4]]) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return type(source_layer) is LinearScalingRotaryEmbedding or type( + source_layer) is RotaryEmbedding + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index cd45040bcca5d..d001d17144d98 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,8 @@ import math import os import re -from typing import Callable, Dict, List, Optional, Tuple, Type +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch @@ -11,7 +12,9 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) @@ -22,10 +25,27 @@ _GLOBAL_LORA_ID = 0 +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. + offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + def convert_mapping( - mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], - max_loras: int, vocab_size: int, extra_vocab_size: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional[LongContextLoRAContext] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: """Converts LoRAMapping to index tensors. Args: @@ -34,6 +54,7 @@ def convert_mapping( max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. Returns: A tuple of tensors: @@ -51,11 +72,23 @@ def convert_mapping( requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. + Used to index into each tensor. It contains length for + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). If long_lora doesn't + exist, it only contains first 4 entries. """ index_mapping_indices: List[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping @@ -66,13 +99,22 @@ def convert_mapping( lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - index_mapping_indices[i] = i lora_indices[i] = lora_idx - - indices = torch.tensor( - [index_mapping_indices, lora_indices, embedding_indices], - dtype=torch.long, - device="cuda") + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + # SANG-TODO + # index_mapping_indices[i] = i + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, lora_indices, embedding_indices + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") prompt_mapping_tensor = torch.tensor(prompt_mapping, device="cuda", dtype=torch.long) @@ -89,13 +131,21 @@ def convert_mapping( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (sampler_indices_padded * len(sampler_indices_padded))) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. indices_len = [ base_indices.shape[-1], sampler_indices.shape[-1], sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, indices_len) + embeddings_indices, long_lora_indices, indices_len) def get_lora_id(): @@ -112,8 +162,20 @@ def __init__( lora_model_id: int, rank: int, loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor assert (lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank @@ -150,6 +212,7 @@ def from_lora_tensors( dtype: Optional[torch.dtype] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, + scaling_factor: Optional[float] = None, embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": @@ -199,13 +262,15 @@ def from_lora_tensors( for lora in loras.values(): lora.optimize() - return cls(lora_model_id, rank, loras) + return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor) @classmethod def from_local_checkpoint( cls, lora_dir: str, expected_lora_modules: List[str], + *, + max_position_embeddings: Optional[int] = None, lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -213,7 +278,23 @@ def from_local_checkpoint( embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": - """Create a LoRAModel from a local checkpoint.""" + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + max_position_embeddings: Max position embedding length. Used to + scaling the largest context length. If None, the lora model's + context length is not scaled. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ lora_config_path = os.path.join(lora_dir, "adapter_config.json") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") @@ -253,6 +334,14 @@ def from_local_checkpoint( rank = config["r"] lora_alpha = config["lora_alpha"] + context_length = config.get("context_length", None) + scaling_factor = None + if context_length: + if max_position_embeddings is None: + max_position_embeddings = context_length + scaling_factor = float( + math.ceil(context_length / max_position_embeddings)) + return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, @@ -263,6 +352,7 @@ def from_local_checkpoint( dtype=dtype, embeddings=embeddings, target_embedding_padding=target_embedding_padding, + scaling_factor=scaling_factor, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, ) @@ -296,6 +386,7 @@ def __init__( self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") @@ -309,6 +400,12 @@ def __init__( self.max_num_batched_tokens, dtype=torch.long, device="cuda") + self.long_lora_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} # 4 is the number of indicies tensors defined above # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices @@ -318,6 +415,10 @@ def __init__( if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = copy.deepcopy( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} @@ -383,12 +484,32 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora + self._set_long_lora_context(lora) def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager CPU cache.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") @@ -400,15 +521,18 @@ def remove_lora(self, lora_id: int) -> bool: """Remove a LoRAModel from the manager CPU cache.""" # TODO: should we check active lora? self.deactivate_lora(lora_id) + if self.long_lora_context: + self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) return bool(self._registered_loras.pop(lora_id, None)) # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, + embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size) + self.lora_config.lora_extra_vocab_size, + self.long_lora_context) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( @@ -416,6 +540,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: self.embeddings_indices[:embeddings_indices. shape[0], :embeddings_indices.shape[1]].copy_( embeddings_indices) + if long_lora_offsets_tensor is not None: + self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self.long_lora_indices.zero_() # Maintain the reference self.indices_len[:] = indices_len @@ -438,7 +567,8 @@ def remove_all_loras(self): self._active_loras.clear() def _create_lora_modules(self): - for module_name, module in self.model.named_modules(): + for module_name, module in self.model.named_modules( + remove_duplicate=False): if not self._match_target_modules(module_name): continue parts = module_name.split(".")[-1] @@ -447,6 +577,13 @@ def _create_lora_modules(self): self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) + # LinearScalingRotaryEmbeddingWithLora is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset # (yard1): TODO make this more robust if "lm_head" in module_name: logits_processor_module = self.model.get_submodule( @@ -461,7 +598,8 @@ def _create_lora_modules(self): self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, self.sampler_indices_padded, - self.embeddings_indices, self.indices_len) + self.embeddings_indices, + self.long_lora_indices, self.indices_len) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) @@ -471,12 +609,14 @@ def create_dummy_lora( self, lora_id: int, rank: int, + scaling_factor: Optional[float], embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" - model = LoRAModel(lora_id, rank, {}) + model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name) or not isinstance( - module, BaseLayerWithLoRA): + module, BaseLayerWithLoRA) or isinstance( + module, LinearScalingRotaryEmbeddingWithLora): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -606,6 +746,10 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: self._add_lora(lora) was_added = True diff --git a/vllm/lora/request.py b/vllm/lora/request.py index bbbf4880ab81b..662774ffe09ae 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional @dataclass @@ -18,6 +19,7 @@ class LoRARequest: lora_name: str lora_int_id: int lora_local_path: str + long_lora_max_len: Optional[int] = None def __post_init__(self): if self.lora_int_id < 1: diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 9942a5fd40dec..fcc7f24721939 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -13,6 +13,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, @@ -26,12 +27,18 @@ logger = init_logger(__name__) _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { - VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora, - MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, - LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, } diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 377f561cceaf2..d67ce67172e30 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Set, Type, Union +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch @@ -17,11 +17,16 @@ class AbstractWorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" - def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, - vocab_size: int, lora_config: LoRAConfig, - device: torch.device): + def __init__(self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + max_position_embeddings: Optional[int] = None): self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = max_num_batched_tokens + self.max_position_embeddings = max_position_embeddings self.vocab_size = vocab_size self.device = device self.lora_config = lora_config @@ -92,14 +97,21 @@ def __init__( embedding_modules: Dict[str, str], embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules # Lazily initialized by create_lora_manager. self._lora_manager: LoRAModelManager - super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, - lora_config, device) + super().__init__( + max_num_seqs, + max_num_batched_tokens, + vocab_size, + lora_config, + device, + max_position_embeddings=max_position_embeddings, + ) @property def is_enabled(self) -> bool: @@ -162,6 +174,7 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, expected_lora_modules, + max_position_embeddings=self.max_position_embeddings, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, @@ -191,7 +204,7 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: lora_request.lora_int_id) else: dummy_lora = self._lora_manager.create_dummy_lora( - lora_request.lora_int_id, rank, self.embedding_modules) + lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora return self._lora_manager.add_lora(dummy_lora) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 4758ca9660083..d03903d206d33 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -61,6 +61,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.base = base self.is_neox_style = is_neox_style + self.dtype = dtype cache = self._compute_cos_sin_cache() cache = cache.to(dtype) @@ -168,6 +169,29 @@ def extra_repr(self) -> str: class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + Credits to the Reddit user /u/kaiokendev """ @@ -183,13 +207,18 @@ def __init__( ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] - self.scaling_factors = scaling_factors + self.scaling_factors: List[float] = scaling_factors # noqa super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - cache_list = [] + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] for scaling_factor in self.scaling_factors: # NOTE(woosuk): self.max_position_embeddings is the original # maximum length before applying the rope scaling. @@ -203,9 +232,25 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: cos = freqs.cos() sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) return torch.cat(cache_list, dim=0) + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 29c76682109c6..ed65d76f7b5b9 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -348,6 +348,8 @@ def __init__( super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ebdc64e0e220e..f2996c240aaf4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" ] embedding_modules = { "embed_tokens": "input_embeddings", diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index c4244f8c77f44..49d2b8d8e21b1 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -46,6 +46,8 @@ def __init__(self, self.kv_channels = kv_channels self.num_attention_heads = num_attention_heads self.seq_length = seq_length + # It is to be compatible with long lora. + self.max_position_embeddings = seq_length self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.layernorm_epsilon = layernorm_epsilon diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 927cbeed073bf..9614f01d2b955 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -34,12 +34,26 @@ def get_max_input_len(self, """Get the maximum input length for the LoRA request.""" return self.max_input_length + def _raise_if_input_too_long(self, + encoded_tokens: List[str], + lora_request: Optional[LoRARequest] = None): + input_length = len(encoded_tokens) + if lora_request: + max_input_length = (lora_request.long_lora_max_len + or self.max_input_length) + else: + max_input_length = self.max_input_length + if max_input_length is not None and input_length > max_input_length: + raise ValueError("Input too long.", input_length, max_input_length) + def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret async def encode_async( self, @@ -47,7 +61,9 @@ async def encode_async( request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret def get_lora_tokenizer( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cd7af25654b52..e264fede0ee64 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -156,9 +156,15 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config. + max_position_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): From f68470e803df575f294e67167b4b83adfe004cfa Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 19 May 2024 15:13:33 +0800 Subject: [PATCH 077/124] [Bugfix][Model] Add base class for vision-language models (#4809) --- tests/models/test_registry.py | 9 ++++ vllm/model_executor/model_loader/loader.py | 13 +++--- vllm/model_executor/models/llava.py | 48 +++++++++++----------- vllm/model_executor/models/vlm_base.py | 12 ++++++ 4 files changed, 53 insertions(+), 29 deletions(-) create mode 100644 tests/models/test_registry.py create mode 100644 vllm/model_executor/models/vlm_base.py diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py new file mode 100644 index 0000000000000..547ab10051f1b --- /dev/null +++ b/tests/models/test_registry.py @@ -0,0 +1,9 @@ +import pytest + +from vllm.model_executor.models import _MODELS, ModelRegistry + + +@pytest.mark.parametrize("model_cls", _MODELS) +def test_registry_imports(model_cls): + # Ensure all model classes can be imported successfully + ModelRegistry.load_model_cls(model_cls) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index dc568928b2859..d1ab207549790 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -26,11 +26,7 @@ download_weights_from_hf, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.llava import LlavaForConditionalGeneration - -_VISION_MODEL_CLASSES = [ - LlavaForConditionalGeneration, -] +from vllm.model_executor.models.vlm_base import VisionLanguageModelBase logger = init_logger(__name__) @@ -73,7 +69,12 @@ def _get_model_initialization_kwargs( "but LoRA is enabled. Support for this model may " "be added in the future. If this is important to you, " "please open an issue on github.") - elif model_class in _VISION_MODEL_CLASSES: + elif issubclass(model_class, VisionLanguageModelBase): + if vision_language_config is None: + raise ValueError("Provide `image_input_type` and other vision " + "related configurations through LLM entrypoint " + "or engine arguments.") + extra_kwargs["vision_language_config"] = vision_language_config return extra_kwargs diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3b99b337a2765..e8a5b6237d4db 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -19,6 +19,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .vlm_base import VisionLanguageModelBase + _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", "language_model.model": "language_model", @@ -40,7 +42,7 @@ def __init__(self, vision_hidden_size: int, text_hidden_size: int, text_hidden_size, bias=True) - def forward(self, image_features): + def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -50,29 +52,31 @@ def forward(self, image_features): def _merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, vision_embeddings: torch.Tensor, - image_token_id: int): + image_token_id: int) -> torch.Tensor: """In place merges in vision_embeddings with inputs_embeds.""" mask = (input_ids == image_token_id) - inputs_embeds[mask] = vision_embeddings.view(-1, + + image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1] + if mask.sum() != image_feature_size: + raise ValueError(f"image_feature_size should be {image_feature_size}, " + f"but found: {mask.sum()}") + + inputs_embeds[mask] = vision_embeddings.view(image_feature_size, vision_embeddings.shape[-1]) + return inputs_embeds -class LlavaForConditionalGeneration(nn.Module): + +class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, - config: "LlavaConfig", + config: LlavaConfig, vision_language_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional["QuantizationConfig"] = None) -> None: - super().__init__() - self.config = config - - self.vision_language_config = vision_language_config + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) - assert self.vision_language_config, ( - "Provide `image_input_type` and other vision " - "related configurations through LLM entrypoint " - "or engine arguments.") + self.config = config if self.vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): @@ -98,14 +102,12 @@ def __init__(self, config.vocab_size, logit_scale) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - image_input: Optional[torch.Tensor] = None - ) -> SamplerOutput: # noqa: E501 + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + image_input: Optional[torch.Tensor] = None) -> SamplerOutput: """Run forward pass for Llava 1.5. One key thing to understand is the `input_ids` already accounts for the @@ -172,7 +174,7 @@ def forward( image_features = image_input vision_embeddings = self.multi_modal_projector(image_features) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - _merge_vision_embeddings( + inputs_embeds = _merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) input_ids = None diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py new file mode 100644 index 0000000000000..eb0aa96e50d59 --- /dev/null +++ b/vllm/model_executor/models/vlm_base.py @@ -0,0 +1,12 @@ +from torch import nn + +from vllm.config import VisionLanguageConfig + + +class VisionLanguageModelBase(nn.Module): + """Base class for all vision language models (VLMs).""" + + def __init__(self, vision_language_config: VisionLanguageConfig) -> None: + super().__init__() + + self.vision_language_config = vision_language_config From 27ce85476e6b170c5c90c65ac5c3268911135766 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Sun, 19 May 2024 11:37:34 -0400 Subject: [PATCH 078/124] [Kernel] Add marlin_24 unit tests (#4901) --- tests/kernels/test_marlin_gemm.py | 87 ++++- .../layers/quantization/gptq_marlin_24.py | 27 +- .../layers/quantization/utils/format_24.py | 308 ++++++++++++++++++ .../quantization/utils/marlin_24_perms.py | 58 ++++ .../layers/quantization/utils/marlin_perms.py | 58 ++++ .../layers/quantization/utils/marlin_utils.py | 214 +++++++----- 6 files changed, 649 insertions(+), 103 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/utils/format_24.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_24_perms.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_perms.py diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index b0ad85c25c572..587fc3901eb7c 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -7,23 +7,32 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) + MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, + marlin_quantize, marlin_weights) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] -K_CHUNKS = [128, 256] -N_CHUNKS = [64, 128, 256] +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 128, 256] + +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [256] MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), (1, 7, 5), - (1, 7 * 4, 5 * 1), (13, 17, 67), (26, 37, 13), (67, 13, 11), @@ -31,14 +40,13 @@ def rand_data(shape): - data = torch.rand(shape).to(torch.half).cuda() - return data + return torch.randn(shape, dtype=torch.half, device="cuda") @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @@ -136,7 +145,8 @@ def test_marlin_gemm( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, num_bits, group_size, act_order) - workspace = MarlinWorkspace(size_n) + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) output = ops.gptq_marlin_gemm( a_input, @@ -155,4 +165,55 @@ def test_marlin_gemm( torch.cuda.synchronize() - assert torch.allclose(output, output_ref, rtol=1e-2) + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + + workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + output_ref = torch.matmul(a_input, w_24_ref) + + output = ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 1bd6127104654..f5345c0443029 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -12,6 +12,15 @@ logger = init_logger(__name__) +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +GPTQ_MARLIN_24_SUPPORTED_SYM = [True] + class GPTQMarlin24Config(QuantizationConfig): """Config class for Marlin24. @@ -25,15 +34,17 @@ def __init__( self.weight_bits = weight_bits self.group_size = group_size - if self.weight_bits != 4 and self.weight_bits != 8: - raise ValueError("weight_bits must be 4 or 8. Got = {}".format( - self.weight_bits)) - - if self.group_size != 128 and self.group_size != -1: + # Verify + if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin_24 does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin24, but got group_size of " - f"{self.group_size}") + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // self.weight_bits diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/format_24.py new file mode 100644 index 0000000000000..01c8cf789204b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/format_24.py @@ -0,0 +1,308 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# + +import torch + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py new file mode 100644 index 0000000000000..12e77cb710687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 +# +# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms_24(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +marlin_24_perm = {} +marlin_24_scale_perm = {} +marlin_24_scale_perm_single = {} +for num_bits in [4, 8]: + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py new file mode 100644 index 0000000000000..76bd2ff7c724e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +marlin_perm = {} +marlin_scale_perm = {} +marlin_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = get_perms(num_bits) + marlin_perm[num_bits] = perm + marlin_scale_perm[num_bits] = scale_perm + marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 33b3169983475..0d027d0620ab3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,79 +1,28 @@ """This file is used for /tests and /benchmarks""" +import random + import numpy import torch -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.format_24 import ( + mask_creator, sparse_semi_structured_from_dense_cutlass) +from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( + marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm, marlin_scale_perm, marlin_scale_perm_single) from vllm.model_executor.layers.quantization.utils.quant_utils import ( get_pack_factor, quantize_weights, sort_weights) __cuda_arch = torch.cuda.get_device_capability() +MARLIN_TILE = 16 + def is_marlin_supported(): return __cuda_arch[0] >= 8 -# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 -# -# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def _get_perms(num_bits): - perm_list = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm = {} -_scale_perm = {} -_scale_perm_single = {} -for num_bits in [4, 8]: - perm, scale_perm, scale_perm_single = _get_perms(num_bits) - _perm[num_bits] = perm - _scale_perm[num_bits] = scale_perm - _scale_perm_single[num_bits] = scale_perm_single - - -def marlin_permute_weights(q_w, - size_k, - size_n, - num_bits, - tile=GPTQ_MARLIN_TILE): +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" @@ -83,15 +32,14 @@ def marlin_permute_weights(q_w, q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) - q_w = q_w.reshape( - (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) return q_w -def marlin_weights(q_w, size_k, size_n, num_bits): +def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) # Pack pack_factor = get_pack_factor(num_bits) @@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits): q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) - for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i @@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits): return q_packed -def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, + scale_perm_single): if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(_scale_perm[num_bits])))[:, - _scale_perm[num_bits]] + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape( - (-1, - len(_scale_perm_single[num_bits])))[:, - _scale_perm_single[num_bits]] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -148,8 +92,11 @@ def marlin_quantize( q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] @@ -159,15 +106,118 @@ def marlin_quantize( return res_list +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, marlin_24_perm[num_bits]) + marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_24_scale_perm[num_bits], + marlin_24_scale_perm_single[num_bits]) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + class MarlinWorkspace: - def __init__(self, out_features): - assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( - "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" - .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) - max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * - GPTQ_MARLIN_MAX_PARALLEL) + max_workspace_size = ((out_features // min_thread_n) * max_parallel) self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, From b57e6c59491ea7d60af413ad4a6455812b9c6c50 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 19 May 2024 18:11:30 -0700 Subject: [PATCH 079/124] [Kernel] Add flash-attn back (#4907) --- requirements-cuda.txt | 2 +- tests/kernels/test_flash_attn.py | 208 ++++++++++++++++++++++++++ tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 129 +++++++++------- vllm/attention/selector.py | 14 ++ 6 files changed, 304 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_flash_attn.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index ba8c614d205d2..acb0164007dba 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 -vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..22772d4ea4422 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,208 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c02204f16ac68..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index e87a1783a83f1..664e951a89f2a 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 856f399741375..0361dd3bd4ead 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,19 +1,15 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) + +_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] class FlashAttentionBackend(AttentionBackend): @@ -37,8 +33,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -46,18 +43,26 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -99,6 +104,14 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # so far). context_lens_tensor: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. @@ -219,11 +232,15 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") def forward( self, @@ -234,17 +251,20 @@ def forward( attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -252,16 +272,20 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -281,7 +305,8 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -302,38 +327,34 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - output[:num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, ) + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, + output[num_prefill_tokens:] = flash_attn_with_kvcache( + decode_query.unsqueeze(1), key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_scale, - ) + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 06f99718a4dee..5140c3cc86a31 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,6 +93,20 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + + if sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: From 6287537a0c970bda1fc8b31f2bde1bcf2d26e151 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 16:11:25 +0800 Subject: [PATCH 080/124] [Model] LLaVA model refactor (#4910) --- vllm/model_executor/models/llava.py | 137 ++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e8a5b6237d4db..fbd7638097286 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch from torch import nn @@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor, return inputs_embeds +class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channels, height, width)""" + + +class LlavaImageFeatureInputs(TypedDict): + type: Literal["image_features"] + data: torch.Tensor + """Shape: (batch_size, image_feature_size, hidden_size)""" + + +LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] + + class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, @@ -102,6 +117,90 @@ def __init__(self, config.vocab_size, logit_scale) self.sampler = Sampler() + def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != list( + self.vision_language_config.image_input_shape[1:]): + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"{self.vision_language_config.image_input_shape[1:]}. " + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _parse_and_validate_image_input( + self, data: object) -> Optional[LlavaImageInputs]: + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if data is None: + return None + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image pixel vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImagePixelInputs( + type="pixel_values", + data=self._validate_image_data(data), + ) + elif expected_input_type == ImageInputType.IMAGE_FEATURES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image feature vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImageFeatureInputs( + type="image_features", + data=self._validate_image_data(data), + ) + + return None + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. + image_outputs = vision_tower(pixel_values.to(vision_tower.device), + output_hidden_states=True) + + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _process_image_pixels(self, + inputs: LlavaImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input(self, + image_input: LlavaImageInputs) -> torch.Tensor: + if image_input["type"] == "pixel_values": + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + else: + image_features = image_input["data"] + + return self.multi_modal_projector(image_features) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -144,42 +243,20 @@ def forward(self, For PIXEL_VALUES, expecting [1, 3, 336, 336]. For IMAGE_FEATURES, expecting [1, 576, 1024]. """ - if image_input is not None: - if list(image_input.shape[1:]) != list( - self.vision_language_config.image_input_shape[1:]): - raise ValueError( - f"The expected image tensor shape is batch dimension " - f"plus " - f"{self.vision_language_config.image_input_shape[1:]}." - f" You supplied {image_input.shape}. " - f"If you are using vLLM's entrypoint, make sure your " - f"supplied image input is consistent with " - f"image_input_shape in engine args.") - if self.vision_tower is not None: - # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. - image_outputs = self.vision_tower(image_input, - output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.config.vision_feature_layer] - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if self.config.vision_feature_select_strategy == "default": - image_features = image_features[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - image_features = image_features - else: - raise ValueError( - f"Unexpected select feature strategy: " - f"{self.config.vision_feature_select_strategy}") - else: - image_features = image_input - vision_embeddings = self.multi_modal_projector(image_features) + parsed_image_input = self._parse_and_validate_image_input(image_input) + + if parsed_image_input is not None: + vision_embeddings = self._process_image_input(parsed_image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = _merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) + input_ids = None else: inputs_embeds = None + hidden_states = self.language_model(input_ids, positions, kv_caches, From da5a0b539d6a5fe0c0195513a797814d2c267540 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Mon, 20 May 2024 10:55:34 -0400 Subject: [PATCH 081/124] Remove marlin warning (#4918) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index fdc0ebef4672e..34950a5d13cf5 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1519,10 +1519,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM " - "GPU cache. This may " - "hurt performance. Consider upgrading your GPU.\n"); - max_m_blocks--; // Process less M blocks per invocation to reduce cache // usage } From 546a97ef691f242c899a5e0906d0e75f42694e95 Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Tue, 21 May 2024 01:45:06 +0800 Subject: [PATCH 082/124] [Misc]: allow user to specify port in distributed setting (#4914) --- vllm/envs.py | 7 +++++++ vllm/utils.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 68d8a074d0914..56ff79e0cdea9 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -3,6 +3,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None VLLM_USE_MODELSCOPE: bool = False VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None @@ -96,6 +97,12 @@ 'VLLM_HOST_IP': lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + # used in distributed environment to manually set the communication port + # '0' is used to make mypy happy + 'VLLM_PORT': + lambda: int(os.getenv('VLLM_PORT', '0')) + if 'VLLM_PORT' in os.environ else None, + # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers "VLLM_USE_MODELSCOPE": diff --git a/vllm/utils.py b/vllm/utils.py index f0e71f5e99b64..552b43e7f82b2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -282,6 +282,9 @@ def get_distributed_init_method(ip: str, port: int) -> str: def get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + return port # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: From 943e72ca56974b4d8b5a141182e717d2abd3a819 Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Mon, 20 May 2024 13:29:28 -0500 Subject: [PATCH 083/124] [Build/CI] Enabling AMD Entrypoints Test (#4834) Co-authored-by: Alexey Kondratiev --- .buildkite/test-pipeline.yaml | 3 ++- Dockerfile.rocm | 8 ++++++-- requirements-rocm.txt | 3 ++- tests/spec_decode/e2e/conftest.py | 8 ++++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6f5c46e23779f..def8a460e84a7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -60,7 +60,8 @@ steps: command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] + commands: # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index eefad79e79d83..9bfe8446a519d 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -92,19 +92,23 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ WORKDIR /vllm-workspace COPY . . +#RUN python3 -m pip install pynvml # to be removed eventually RUN python3 -m pip install --upgrade pip numba # make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 +# Workaround for ray >= 2.10.0 +ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 + +ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 CMD ["/bin/bash"] diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 903845b64d98f..cc42839a975d0 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,4 +2,5 @@ -r requirements-common.txt # Dependencies for AMD GPUs -ray == 2.9.3 +ray >= 2.10.0 +pytest-asyncio diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index da8b92711380e..7c5840baf3593 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -6,8 +6,12 @@ import pytest import ray import torch -from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit) + +from vllm.utils import is_hip + +if (not is_hip()): + from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs From f0eecee6106774e1e0f9b31c7438cde77654df52 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 20 May 2024 21:44:25 +0300 Subject: [PATCH 084/124] [Bugfix] Fix dummy weight for fp8 (#4916) Allow dummy load format for fp8, torch.uniform_ doesn't support FP8 at the moment Co-authored-by: Mor Zusman --- vllm/model_executor/model_loader/weight_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c1abde9af7701..a1642baa2c90c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -369,4 +369,11 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): - param.data.uniform_(low, high) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high) From 1937e29848c8de8634c5421612d57863aa0e2a51 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 20 May 2024 14:46:12 -0400 Subject: [PATCH 085/124] [Core] Sharded State Loader download from HF (#4889) --- vllm/model_executor/model_loader/loader.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d1ab207549790..45ea8160a801b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -423,6 +423,16 @@ def get_end_ptr(tensor: torch.Tensor) -> int: result[k] = t return result + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns, revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -433,6 +443,10 @@ def load_model(self, *, model_config: ModelConfig, from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights(model_config.model, + model_config.revision) + with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, @@ -440,7 +454,7 @@ def load_model(self, *, model_config: ModelConfig, cache_config) rank = get_tensor_model_parallel_rank() pattern = os.path.join( - model_config.model, + local_model_path, self.pattern.format(rank=rank, part="*"), ) filepaths = glob.glob(pattern) From c3af44722cff56bba5fc912c8e16d9de02dfb532 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 20 May 2024 13:16:57 -0700 Subject: [PATCH 086/124] [Doc]Add documentation to benchmarking script when running TGI (#4920) --- benchmarks/benchmark_serving.py | 4 ++++ benchmarks/launch_tgi_server.sh | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 2c2d69da4a7d1..9c3fed4817de2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -17,6 +17,10 @@ --dataset-path \ --request-rate \ # By default is inf --num-prompts # By default is 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. """ import argparse import asyncio diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index 64d3c4f4b3889..f491c90d0683e 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -4,7 +4,7 @@ PORT=8000 MODEL=$1 TOKENS=$2 -docker run --gpus all --shm-size 1g -p $PORT:80 \ +docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ ghcr.io/huggingface/text-generation-inference:1.4.0 \ --model-id $MODEL \ From 65ae8c2c8f52e0c98e4e26ad1255772d888592a6 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 20 May 2024 17:48:32 -0700 Subject: [PATCH 087/124] [Core] Fix scheduler considering "no LoRA" as "LoRA" (#4897) --- vllm/core/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c8da54f2889eb..7c70b1b244f7d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -744,8 +744,8 @@ def _schedule_default(self) -> SchedulerOutputs: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) From d130b573a0162173002b97e2112c6c1c10d0ca8e Mon Sep 17 00:00:00 2001 From: HUANG Fei Date: Tue, 21 May 2024 13:22:22 +0800 Subject: [PATCH 088/124] [Model] add rope_scaling support for qwen2 (#4930) --- vllm/model_executor/models/qwen2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 31ba6441f9f7a..97ab6168c3230 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -89,7 +89,8 @@ def __init__(self, use_sliding_window: bool = False, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + sliding_window: Optional[int] = None, + rope_scaling: Optional[Tuple] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -133,6 +134,7 @@ def __init__(self, rotary_dim=self.head_dim, max_position=max_position, base=self.rope_theta, + rope_scaling=rope_scaling, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -169,6 +171,7 @@ def __init__( self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) use_sliding_window = (config.use_sliding_window and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( @@ -180,7 +183,8 @@ def __init__( use_sliding_window=use_sliding_window, cache_config=cache_config, quant_config=quant_config, - sliding_window=config.sliding_window) + sliding_window=config.sliding_window, + rope_scaling=rope_scaling) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, From f12c3b5b3d076a67662b76d215fd875fd6cdf6d7 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Tue, 21 May 2024 13:24:17 +0800 Subject: [PATCH 089/124] [Model] Add Phi-2 LoRA support (#4886) --- docs/source/models/supported_models.rst | 2 +- tests/lora/conftest.py | 5 ++ tests/lora/test_phi.py | 67 +++++++++++++++++++++++++ vllm/model_executor/models/phi.py | 33 +++++++++--- 4 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 tests/lora/test_phi.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 142c8f8573e2f..31d4b53bd4409 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it. * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - ✅︎ * - :code:`Phi3ForCausalLM` - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 5c648f72d8ddd..95fc65cdd1a8f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -165,6 +165,11 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def phi2_lora_files(): + return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") + + @pytest.fixture(scope="session") def long_context_lora_files_16k_1(): return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py new file mode 100644 index 0000000000000..a2b42ce4cb96f --- /dev/null +++ b/tests/lora/test_phi.py @@ -0,0 +1,67 @@ +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "microsoft/phi-2" + +PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 + + +def do_sample(llm, lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format( + sql_prompt= + "Which catalog publisher has published the most catalogs?", + context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), + PROMPT_TEMPLATE.format( + sql_prompt= + "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + context= + "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + sql_prompt= + "How many marine species are found in the Southern Ocean?", # noqa: E501 + context= + "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=64, + stop="### End") + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_phi2_lora(phi2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=2, + enforce_eager=True) + + expected_lora_output = [ + "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 + "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 + "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 + ] + + output1 = do_sample(llm, phi2_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) + output2 = do_sample(llm, phi2_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i].startswith(expected_lora_output[i]) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index ed25a232f4208..193a29d20c894 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,7 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -229,11 +229,32 @@ def forward( class PhiForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "dense", + "fc1", + "fc2", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + del lora_config # Unused. super().__init__() self.config = config self.quant_config = quant_config From e941f885843d4bcd239f805a9267729e9631556f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 21 May 2024 00:17:25 -0700 Subject: [PATCH 090/124] [Docs] Add acknowledgment for sponsors (#4925) --- README.md | 25 +++++++++++++++++++++++++ docs/source/community/sponsors.md | 24 ++++++++++++++++++++++++ docs/source/index.rst | 1 + 3 files changed, 50 insertions(+) create mode 100644 docs/source/community/sponsors.md diff --git a/README.md b/README.md index fc3f71b00c3c5..627e45d4de0c4 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,31 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +## Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. + ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md new file mode 100644 index 0000000000000..532ce77beb7b8 --- /dev/null +++ b/docs/source/community/sponsors.md @@ -0,0 +1,24 @@ +# Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index bab00e28e4018..5db1c9346c45d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -118,6 +118,7 @@ Documentation :caption: Community community/meetups + community/sponsors Indices and tables ================== From 757b62c49560baa6f294310a53032348a0d95939 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 21 May 2024 12:06:10 -0400 Subject: [PATCH 091/124] [CI/Build] Codespell ignore `build/` directory (#4945) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1c61a9e955b61..96f78c37cfefb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data" +skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true From 14772eeb8e8ec76e5e70142d12a7332fcec28ccb Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Wed, 22 May 2024 00:30:52 +0800 Subject: [PATCH 092/124] [Bugfix] Fix flag name for `max_seq_len_to_capture` (#4935) Signed-off-by: kerthcet --- vllm/engine/arg_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d0edf0a75b710..1ba424c4eeb14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -341,9 +341,9 @@ def add_cli_args( help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode. ' - '(DEPRECATED. Use --max-seq_len-to-capture instead' + '(DEPRECATED. Use --max-seq-len-to-capture instead' ')') - parser.add_argument('--max-seq_len-to-capture', + parser.add_argument('--max-seq-len-to-capture', type=int, default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' From 99eff67ba9155b5fec9a9abd939e3a29a1b42dce Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 22 May 2024 03:33:25 +0800 Subject: [PATCH 093/124] [Bugfix][Kernel] Add head size check for attention backend selection (#4944) --- vllm/attention/backends/flash_attn.py | 12 ++++++++---- vllm/attention/selector.py | 16 +++++++++++++--- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0361dd3bd4ead..0f4568070cfc4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,11 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] - class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod def get_name() -> str: return "flash-attn" @@ -237,10 +239,12 @@ def __init__( # paged KV cache. raise ValueError( "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Supported head sizes are: {support_head_sizes}.") def forward( self, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..51c25a81b4130 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -34,11 +34,21 @@ def get_attn_backend( sliding_window, dtype, kv_cache_dtype, block_size) if backend == _Backend.FLASH_ATTN: - logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - return FlashAttentionBackend - elif backend == _Backend.XFORMERS: + + # We check it here not in _which_attn_to_use because we cannot know + # the head size until we import FlashAttentionBackend. + supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size in supported_head_sizes: + logger.info("Using FlashAttention-2 backend.") + return FlashAttentionBackend + logger.info( + "Cannot use FlashAttention-2 backend for head size %d. " + "Using XFormers backend instead.", head_size) + backend = _Backend.XFORMERS + + if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) From 9b9a10d6cb89f18e054daa66f25cb8f17c723b2c Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 22 May 2024 05:32:35 +0000 Subject: [PATCH 094/124] [Frontend] Dynamic RoPE scaling (#4638) --- tests/test_config.py | 56 ++++++++++++++++++++++++++++++- vllm/config.py | 7 +++- vllm/engine/arg_utils.py | 18 +++++++--- vllm/engine/llm_engine.py | 10 +++--- vllm/transformers_utils/config.py | 10 +++++- 5 files changed, 89 insertions(+), 12 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 19db10630bbae..6bc51a53dc07c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -36,4 +36,58 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() is None mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW \ No newline at end of file + assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW + + +def test_rope_scaling(): + TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None + assert llama_model_config.max_model_len == 8192 + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert llama_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == LONGCHAT_ROPE_SCALING + assert longchat_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert longchat_model_config.max_model_len == 4096 diff --git a/vllm/config.py b/vllm/config.py index 44ed5635f9a35..3256c11967914 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -45,6 +45,9 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -84,6 +87,7 @@ def __init__( seed: int, revision: Optional[str] = None, code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, @@ -102,6 +106,7 @@ def __init__( self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling self.tokenizer_revision = tokenizer_revision self.quantization = quantization self.quantization_param_path = quantization_param_path @@ -116,7 +121,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision) + code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1ba424c4eeb14..0a9ec7472fbca 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,5 +1,6 @@ import argparse import dataclasses +import json from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -49,6 +50,7 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False @@ -330,6 +332,11 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"type":"dynamic","factor":2.0}') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' @@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig: model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.dtype, self.seed, self.revision, - self.code_revision, self.tokenizer_revision, self.max_model_len, - self.quantization, self.quantization_param_path, - self.enforce_eager, self.max_context_len_to_capture, - self.max_seq_len_to_capture, self.max_logprobs, - self.skip_tokenizer_init, self.served_model_name) + self.code_revision, self.rope_scaling, self.tokenizer_revision, + self.max_model_len, self.quantization, + self.quantization_param_path, self.enforce_eager, + self.max_context_len_to_capture, self.max_seq_len_to_capture, + self.max_logprobs, self.skip_tokenizer_init, + self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f6a5284093c1c..60e23d4df15bb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -104,10 +104,11 @@ def __init__( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " - "max_seq_len=%d, download_dir=%r, load_format=%s, " - "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " - "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " + "rope_scaling=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, @@ -117,6 +118,7 @@ def __init__( model_config.skip_tokenizer_init, model_config.tokenizer_mode, model_config.revision, + model_config.rope_scaling, model_config.tokenizer_revision, model_config.trust_remote_code, model_config.dtype, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1756c91a612f0..f36d84dbdf7f9 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,9 +2,12 @@ from transformers import AutoConfig, PretrainedConfig +from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, JAISConfig, MPTConfig, RWConfig) +logger = init_logger(__name__) + _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -18,7 +21,8 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, - code_revision: Optional[str] = None) -> PretrainedConfig: + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( model, @@ -41,6 +45,10 @@ def get_config(model: str, config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + if rope_scaling is not None: + logger.info("Updating rope_scaling from %r to %r", + getattr(config, "rope_scaling", None), rope_scaling) + config.update({"rope_scaling": rope_scaling}) return config From 5f6d10c14c17122e6d711a4829ee0ca672e07f6f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 22 May 2024 03:18:41 -0400 Subject: [PATCH 095/124] [CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722) --- .clang-format | 26 + .github/workflows/clang-format.yml | 42 + csrc/activation_kernels.cu | 139 +- csrc/attention/attention_generic.cuh | 19 +- csrc/attention/attention_kernels.cu | 636 ++- csrc/attention/attention_utils.cuh | 11 +- csrc/attention/dtype_bfloat16.cuh | 74 +- csrc/attention/dtype_float16.cuh | 92 +- csrc/attention/dtype_float32.cuh | 88 +- csrc/attention/dtype_fp8.cuh | 32 +- csrc/cache.h | 44 +- csrc/cache_kernels.cu | 288 +- csrc/cpu/activation.cpp | 60 +- csrc/cpu/attention.cpp | 411 +- csrc/cpu/cache.cpp | 53 +- csrc/cpu/layernorm.cpp | 32 +- csrc/cpu/pos_encoding.cpp | 66 +- csrc/cpu/pybind.cpp | 75 +- csrc/cuda_compat.h | 9 +- csrc/cuda_utils.h | 7 +- csrc/cuda_utils_kernels.cu | 40 +- csrc/custom_all_reduce.cu | 55 +- csrc/custom_all_reduce.cuh | 105 +- csrc/custom_all_reduce_test.cu | 38 +- csrc/dispatch_utils.h | 42 +- csrc/layernorm_kernels.cu | 242 +- csrc/moe/moe_ops.cpp | 3 +- csrc/moe/moe_ops.h | 8 +- csrc/moe_align_block_size_kernels.cu | 211 +- csrc/ops.h | 330 +- csrc/pos_encoding_kernels.cu | 229 +- csrc/pybind.cpp | 142 +- csrc/quantization/aqlm/gemm_kernels.cu | 536 +-- csrc/quantization/awq/dequantize.cuh | 138 +- csrc/quantization/awq/gemm_kernels.cu | 611 +-- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 38 +- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 22 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 47 +- csrc/quantization/fp8/amd/hip_float8.h | 216 +- csrc/quantization/fp8/amd/hip_float8_impl.h | 520 +-- csrc/quantization/fp8/amd/quant_utils.cuh | 711 ++-- csrc/quantization/fp8/common.cu | 87 +- csrc/quantization/fp8/nvidia/quant_utils.cuh | 138 +- csrc/quantization/gptq/compat.cuh | 70 +- csrc/quantization/gptq/matrix_view.cuh | 503 +-- csrc/quantization/gptq/q_gemm.cu | 3441 ++++++++--------- csrc/quantization/gptq/qdq_2.cuh | 107 +- csrc/quantization/gptq/qdq_3.cuh | 246 +- csrc/quantization/gptq/qdq_4.cuh | 203 +- csrc/quantization/gptq/qdq_8.cuh | 34 +- csrc/quantization/gptq/qdq_util.cuh | 58 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 696 ++-- csrc/quantization/gptq_marlin/gptq_marlin.cuh | 50 +- .../gptq_marlin/gptq_marlin_dtypes.cuh | 89 +- .../gptq_marlin/gptq_marlin_repack.cu | 94 +- .../marlin/dense/marlin_cuda_kernel.cu | 460 ++- csrc/quantization/marlin/sparse/common/base.h | 12 +- csrc/quantization/marlin/sparse/common/mem.h | 64 +- csrc/quantization/marlin/sparse/common/mma.h | 107 +- .../marlin/sparse/marlin_24_cuda_kernel.cu | 446 ++- .../squeezellm/quant_cuda_kernel.cu | 63 +- csrc/reduction_utils.cuh | 20 +- format.sh | 57 +- requirements-dev.txt | 1 + 64 files changed, 6571 insertions(+), 6963 deletions(-) create mode 100644 .clang-format create mode 100644 .github/workflows/clang-format.yml diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..7f9e6d720fae5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 0000000000000..e9b6e28fa6bcb --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,42 @@ +name: clang-format + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install clang-format==18.1.5 + - name: Running clang-format + run: | + EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' + ) + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ + | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24d972702c858..867f63f12de4b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. -template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float) x; + const float f = (float)x; constexpr float ALPHA = M_SQRT1_2; - return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float) x; + const float f = (float)x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T) (0.5f * f * (1.0f + ::tanhf(inner))); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "act_and_mul_kernel", \ - [&] { \ - vllm::act_and_mul_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); - -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -96,11 +90,11 @@ void gelu_tanh_and_mul( namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -108,54 +102,49 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } -} // namespace vllm +} // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 31fb401cbe2c1..62409c0cce93e 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 41b337dd91d36..d6203174e7275 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -27,15 +28,15 @@ #ifdef USE_ROCM #include #include "../quantization/fp8/amd/quant_utils.cuh" - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #else #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM -#define WARP_SIZE 32 + #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize + #define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -45,7 +46,7 @@ namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -118,22 +116,29 @@ __device__ void paged_attention_kernel( } const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -143,13 +148,14 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -163,18 +169,21 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -193,44 +202,50 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { // Vector conversion from Quant_vec to K_vec. Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -285,13 +300,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -304,7 +318,8 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -315,18 +330,21 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -337,14 +355,17 @@ __device__ void paged_attention_kernel( if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast(v_ptr + offset); } else { - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); } if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { @@ -367,8 +388,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -405,9 +426,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -419,79 +440,75 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride, kv_scale); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ - vllm::paged_attention_v1_kernel<<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel< \ + T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128> +template void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -632,9 +638,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -644,7 +651,8 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -683,19 +691,10 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -716,74 +715,45 @@ void paged_attention_v1_launcher( } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ - max_num_partitions); - -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale){ + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE)} +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -796,9 +766,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -853,59 +824,50 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index ff64c4bd8f80c..cdcee42748998 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,7 +27,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). @@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..3cdcb95e08099 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d3271e69cd69d..3a1815f0ed4fc 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb0..7c6a686db3ba9 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. -inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index 2b32ce372a64f..e714e321b0beb 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -4,38 +4,38 @@ #include #ifdef ENABLE_FP8 -#ifndef USE_ROCM -#include -#endif // USE_ROCM -#endif // ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { enum class Fp8KVCacheDataType { - kAuto = 0, - kFp8E4M3 = 1, - kFp8E5M2 = 2, + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, }; // fp8 vector types for quantization of kv cache -template<> +template <> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template<> +template <> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template<> +template <> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template<> +template <> struct Vec { - using Type = uint2; + using Type = uint2; }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 8c176c452425e..435ae3e57f555 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -5,36 +5,24 @@ #include #include -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping); +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping); +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping); -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const float kv_scale); +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const float kv_scale); -void reshape_and_cache_flash( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); // Just for unittest -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float scale, - const std::string& kv_cache_dtype); +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index e5b74da6ad068..d924ac39b89ca 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -6,9 +6,9 @@ #include "dispatch_utils.h" #ifdef USE_ROCM -#include "quantization/fp8/amd/quant_utils.cuh" + #include "quantization/fp8/amd/quant_utils.cuh" #else -#include "quantization/fp8/nvidia/quant_utils.cuh" + #include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -18,20 +18,17 @@ #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -41,16 +38,17 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - // NOTE(youkaichao): keep in mind that `block_mapping` should be + // NOTE(youkaichao): keep in mind that `block_mapping` should be // a cpu tensor, otherwise every `item` call will require a gpu-cpu // synchronization. TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. const int64_t num_blocks = block_mapping.size(0); @@ -59,29 +57,25 @@ void swap_blocks( int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( - dst_ptr + dst_offset, - src_ptr + src_offset, - block_size_in_bytes, - memcpy_type, - stream); + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel( } } -} // namespace vllm +} // namespace vllm -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping) { +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -118,8 +111,10 @@ void copy_blocks( int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); } // block_mapping is a 2D tensor with shape (num_pairs, 2). @@ -127,10 +122,12 @@ void copy_blocks( // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -139,31 +136,28 @@ void copy_blocks( const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping.data_ptr(), - numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; } else { - key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, kv_scale); } } } -template +template __global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + // head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride - + block_offset * num_heads * head_size - + head_idx * head_size - + head_offset; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; k_cache[tgt_value_idx] = key[src_key_idx]; v_cache[tgt_value_idx] = value[src_value_idx]; } } -} // namespace vllm +} // namespace vllm // KV_T is the stored data type of kv-cache. // CACHE_T is the data type of key and value tensors. // KV_DTYPE is the real data type of kv-cache. -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ - kv_scale); +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, - const float kv_scale) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -283,17 +270,17 @@ void reshape_and_cache( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); @@ -313,62 +300,47 @@ void reshape_and_cache_flash( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_flash", - [&] { - vllm::reshape_and_cache_flash_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - k_cache.data_ptr(), - v_cache.data_ptr(), - slot_mapping.data_ptr(), - block_stride, - key_stride, - value_stride, - num_heads, - head_size, - block_size); - }); + key.scalar_type(), "reshape_and_cache_flash", [&] { + vllm::reshape_and_cache_flash_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + slot_mapping.data_ptr(), block_stride, key_stride, + value_stride, num_heads, head_size, block_size); + }); } namespace vllm { -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const float kv_scale, - const int64_t block_stride) { +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float kv_scale, + const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; - dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], kv_scale); } } -} // namespace vllm +} // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ - kv_scale, \ - block_stride); +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); // Only for testing. -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float kv_scale, - const std::string& kv_cache_dtype) -{ +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); at::cuda::OptionalCUDAGuard device_guard(src_device); int64_t num_blocks = src_cache.size(0); @@ -398,13 +370,15 @@ void convert_fp8( } else if (src_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Float) { CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); } } else { TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 1bd24eb79d129..becd2ac42f17a 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, - scalar_t *__restrict__ output) { +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] }); } -void gelu_new(torch::Tensor &out, torch::Tensor &input) { +void gelu_new(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { }); } -void gelu_fast(torch::Tensor &out, torch::Tensor &input) { +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index c1d765be05598..54df69b7379d6 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,8 @@ namespace { -template struct KernelVecType { +template +struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -11,7 +12,8 @@ template struct KernelVecType { using v_load_vec_type = void; }; -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -21,7 +23,8 @@ template <> struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -30,7 +33,8 @@ template <> struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -41,7 +45,7 @@ template <> struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -67,10 +71,11 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, } template -FORCE_INLINE std::pair -reduceSoftmaxAlibi(T *data, const int size, const int capacity, - const float alibi_slope, const int start_index, - const int seq_len) { +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { @@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, } template -FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -132,9 +137,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t *__restrict__ q, - const scalar_t *__restrict__ k_block, - float *__restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -196,8 +201,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, - acc_t &&acc) { +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void - call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; @@ -243,32 +248,31 @@ struct paged_attention_v1_impl { size_t logits_bytes = parallel_work_item_num * max_seq_len_padded * sizeof(float); - float *logits = (float *)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { int seq_len = seq_lens[seq_idx]; - const int *seq_block_table = + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = - seq_len - (block_num - 1) * BLOCK_SIZE; - float *__restrict__ thread_block_logits = + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -282,8 +286,7 @@ struct paged_attention_v1_impl { block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, seq_len); } else { - reduceSoftmax(thread_block_logits, seq_len, - block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } // Compute value @@ -293,14 +296,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -311,7 +314,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,16 +343,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace +} // namespace -void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, - torch::Tensor &key_cache, torch::Tensor &value_cache, +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, - int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { @@ -434,23 +436,24 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float - *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -468,8 +471,7 @@ struct paged_attention_v2_impl { const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seq_len) - continue; + if (start_token_idx >= seq_len) continue; const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; @@ -477,15 +479,14 @@ struct paged_attention_v2_impl { const int token_num = (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = - (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = token_num - (block_num - 1) * BLOCK_SIZE; - const int *seq_block_table = block_tables + + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -493,10 +494,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -510,13 +511,13 @@ struct paged_attention_v2_impl { logits, token_num, block_num * BLOCK_SIZE, alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, token_num, - block_num * BLOCK_SIZE); + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } - auto &&[max_logit, exp_sum] = max_and_sum; + auto&& [max_logit, exp_sum] = max_and_sum; - scalar_t *__restrict__ output_buffer = nullptr; + scalar_t* __restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -538,13 +539,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -555,7 +556,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -587,8 +588,7 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -603,11 +603,11 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE - // didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float *__restrict__ rescale_factors = exp_sums; + const float* __restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { @@ -616,17 +616,16 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; - const float *__restrict__ seq_head_rescale_factors = + const float* __restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t *__restrict__ seq_head_tmp_out = + const scalar_t* __restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t *__restrict__ seq_head_output = + scalar_t* __restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -645,21 +644,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, - torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, \ - max_seq_len, alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, - torch::Tensor &max_logits, torch::Tensor &tmp_out, - torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, - float scale, torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, +} // namespace + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 26e81685d623e..2890ba6e2bb32 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,25 +5,26 @@ namespace { template -void copy_blocks_cpu_impl( - std::vector &key_caches, - std::vector &value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, const int layer_num) { +void copy_blocks_cpu_impl(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t *source_ptr = key_cache_ptr + source_offset; - scalar_t *target_ptr = key_cache_ptr + target_offset; + scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( template void reshape_and_cache_cpu_impl( - const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, - scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, - const int64_t *__restrict__ slot_mapping, const int num_tokens, + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * head_size; - const scalar_t *src_key_head_ptr = key + src_key_head_idx; - const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t *target_key_head_ptr = key_cache + + scalar_t* target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t *target_value_head_ptr = value_cache + + scalar_t* target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -void copy_blocks(std::vector &key_caches, - std::vector &value_caches, +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -100,10 +101,10 @@ void copy_blocks(std::vector &key_caches, }); } -void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, - torch::Tensor &key_cache, torch::Tensor &value_cache, - torch::Tensor &slot_mapping, - const std::string &kv_cache_dtype, float kv_scale) { +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, }); } -void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const torch::Tensor&block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 467f0dc84982c..65d3ddcec5709 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t *__restrict__ out, - const scalar_t *__restrict__ input, - const scalar_t *__restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t *__restrict__ input, - scalar_t *__restrict__ residual, - const scalar_t *__restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor &out, torch::Tensor &input, - torch::Tensor &weight, float epsilon) { +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon) { +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 5dc1bde45ac5f..73bf77e46f538 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,16 +4,16 @@ namespace { template void rotary_embedding_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -26,7 +26,7 @@ void rotary_embedding_impl( #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; for (int i = 0; i < num_heads; ++i) { const int head_idx = i; @@ -94,16 +94,16 @@ void rotary_embedding_impl( template void rotary_embedding_gptj_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t *head_query = token_head + query; + scalar_t* head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t *head_key = key + token_head; + scalar_t* head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index bba044087f37c..63082393c8102 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); } diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 1ebb2e74a82fc..5909e5eaf5e60 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM -#include + #include #endif #ifndef USE_ROCM @@ -17,7 +17,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif @@ -29,7 +30,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) #else #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #endif @@ -41,4 +43,3 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 1483484faeb4a..2ba49b339e148 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,9 +2,6 @@ #include -int get_device_attribute( - int attribute, - int device_id); +int get_device_attribute(int attribute, int device_id); -int get_max_shared_memory_per_block_device_attribute( - int device_id); +int get_max_shared_memory_per_block_device_attribute(int device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 1a443ef3620cc..7d8e2e19720fa 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,34 +2,28 @@ #include #include #endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; +int get_device_attribute(int attribute, int device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; } - -int get_max_shared_memory_per_block_device_attribute( - int device_id) -{ -int attribute; -// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html -// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 +int get_max_shared_memory_per_block_device_attribute(int device_id) { + int attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 3906dcfc80dbf..0b1d95848525a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -7,11 +7,11 @@ // fake pointer type using fptr_t = uint64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) { +bool _is_weak_contiguous(torch::Tensor& t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } int meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 750e68d42f6c6..1ed49b8aa9cae 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal *signals[8]; }; +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half &assign_add(half &a, half b) { +DINLINE half& assign_add(half& a, half b) { a = __hadd(a, b); return a; } -DINLINE float &assign_add(float &a, float b) { return a += b; } +DINLINE float& assign_add(float& a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->start[blockIdx.x][threadIdx.x]); } __syncthreads(); } @@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->end[blockIdx.x][threadIdx.x]); } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P *ptrs[], int idx) { +DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = - packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Signal *)sg) + 1); +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + ptrs[i] = (const P*)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + ((P*)result)[dst_idx] = tmps[i][idx]; } } } @@ -261,14 +258,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal *self_sg_; + std::unordered_map buffers_; + Signal* self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -279,22 +276,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal *rank_sg; + Signal* rank_sg; if (i != rank_) { - char *handle = open_ipc_handle(&handles[i]); + char* handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal *)handle; + rank_sg = (Signal*)handle; } else { rank_sg = self_sg_; } @@ -302,13 +299,13 @@ class CustomAllreduce { } } - char *open_ipc_handle(const void *ipc_handle) { + char* open_ipc_handle(const void* ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { - char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -323,7 +320,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void *base_ptr; + void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -331,8 +328,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return std::make_pair(handles, offsets); } @@ -344,13 +341,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector &handles, - const std::vector &offsets, void *self) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle = open_ipc_handle(handles[i].data()); + char* handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -371,17 +368,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector &handles, - const std::vector> &offsets) { + const std::vector& handles, + const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto &rd = rank_data[i]; + auto& rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle = + char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -405,7 +402,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -418,7 +415,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData *ptrs; + RankData* ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c34a50389c21c..f7868233076cd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T *data, int size, int myRank) { +__global__ void set_data(T* data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { } template -__global__ void convert_data(const T *data1, const T *data2, double *fdata1, - double *fdata2, int size) { +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, } } -__global__ void init_rand(curandState_t *state, int size, int nRanks) { +__global__ void init_rand(curandState_t* state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, +void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, int data_size, bool performance_test) { - T *result; + T* result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal *buffer; - T *self_data_copy; + vllm::Signal* buffer; + T* self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void *rank_data; + void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto *self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char *begin = (char *)&data_handles[i]; - char *end = (char *)&data_handles[i + 1]; + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *ground_truth; + double* ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t *states; + curandState_t* states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char **argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..3ecea03242f06 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,32 +6,30 @@ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index e56b4d2204005..70a2b3b0a07b1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -11,26 +11,24 @@ #include #include - using __nv_bfloat16 = __hip_bfloat16; - using __nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } - /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. */ -template -struct _typeConvert { static constexpr bool exists = false; }; +template +struct _typeConvert { + static constexpr bool exists = false; +}; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. */ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp += T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] += other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp *= T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] *= other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i+1]}); + float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ -template -__global__ std::enable_if_t< - (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -203,9 +214,12 @@ __global__ std::enable_if_t< /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = reinterpret_cast*>(weight); + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -215,10 +229,11 @@ __global__ std::enable_if_t< residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -233,52 +248,50 @@ __global__ std::enable_if_t< } } - /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. */ -template -__global__ std::enable_if_t< - (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float) z; + float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -286,40 +299,27 @@ void rms_norm( dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "fused_add_rms_norm_kernel", \ - [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), \ - epsilon, \ - num_tokens, \ - hidden_size); \ - }); - -void fused_add_rms_norm( - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm( auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ - && wt_ptr % 16 == 0; + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 35c328499a22d..4122f7630d7c7 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -3,5 +3,6 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); + m.def("topk_softmax", &topk_softmax, + "Apply topk softmax to the gating outputs."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d72..93e7844ac1993 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -2,8 +2,6 @@ #include -void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index e01b23685ef4e..edc441d121029 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -7,119 +7,128 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; } +} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, - size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are assigned - * to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding blocks - * and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { - expert_ids[i / block_size] = threadIdx.x; +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } - - /** - * Each thread processes a token shard, calculating the index of each token after - * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and - * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], - * where * represents a padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] - * stores the indices of the tokens processed by the expert with expert_id within - * the current thread's token shard. - */ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } } - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors - const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK( - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, - block_size, + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 8c2c2ae6e1f5a..f5e0e423bb65d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,224 +2,136 @@ #include -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -void batched_rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_tanh_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, + int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); #ifndef USE_ROCM -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -); - -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy); - -torch::Tensor marlin_gemm( - torch::Tensor& a, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_scales, - torch::Tensor &g_idx, - torch::Tensor &perm, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k, - bool is_k_full); - -torch::Tensor gptq_marlin_repack( - torch::Tensor &b_q_weight, - torch::Tensor &perm, - int64_t size_k, - int64_t size_n, - int64_t num_bits); - -int cutlass_scaled_mm_dq( - torch::Tensor& out, - torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit); - -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void dynamic_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM using fptr_t = uint64_t; -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); void dispose(fptr_t _fa); int meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); #endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index d80cb6973fad6..69d6dae1c26bc 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,14 +7,10 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, - const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, - int rot_offset, - int embed_dim) -{ + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* cache_ptr, - const int head_size, - const int num_heads, - const int num_kv_heads, - const int rot_dim, - const int token_idx, - const int64_t query_stride, - const int64_t key_stride) -{ + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -135,36 +141,21 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); } /* @@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. */ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -191,36 +183,21 @@ void batched_rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f5b4865506568..cba07f0ae9f2a 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - ops.def( - "batched_rotary_embedding", - &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + ops.def("batched_rotary_embedding", &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " + "(supports multiple loras)"); // Quantization ops #ifndef USE_ROCM ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, + "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, + "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, + "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, + "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " + "per-row/column quantization."); #endif - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); - ops.def( - "moe_align_block_size", - &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, + "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, + "Compute FP8 quantized tensor and scaling factor"); + ops.def("moe_align_block_size", &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such " + "that it is divisible by the block size."); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "reshape_and_cache_flash", - &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value cache to fp8 data type"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); + cache_ops.def("convert_fp8", &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); + pybind11::module cuda_utils = + m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def("get_device_attribute", &get_device_attribute, + "Gets the specified device attribute."); - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute, + "Gets the maximum shared memory per block device attribute."); #ifndef USE_ROCM // Custom all-reduce kernels @@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { custom_ar.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); #endif - } diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 4415316e1e8cd..255844eec56d4 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -25,32 +25,28 @@ #include #include - namespace vllm { namespace aqlm { __global__ void Code1x16MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -67,8 +63,7 @@ __global__ void Code1x16MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -76,22 +71,19 @@ __global__ void Code1x16MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); half2* a = reinterpret_cast(&dec); half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(a[j], b[j], res2); +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); b_sh_rd++; } @@ -100,37 +92,33 @@ __global__ void Code1x16MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } __global__ void Code2x8MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -148,9 +136,8 @@ __global__ void Code2x8MatVec( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -161,8 +148,7 @@ __global__ void Code2x8MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -170,13 +156,15 @@ __global__ void Code2x8MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); @@ -187,36 +175,31 @@ __global__ void Code2x8MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } - __global__ void Code1x16Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -231,17 +214,15 @@ __global__ void Code1x16Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); C[a_gl_rd * 8 + i] = chunk; } @@ -250,28 +231,25 @@ __global__ void Code1x16Dequant( } } - __global__ void Code2x8Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -290,9 +268,8 @@ __global__ void Code2x8Dequant( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -302,12 +279,14 @@ __global__ void Code2x8Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - #pragma unroll + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); C[a_gl_rd * 8 + i] = chunk; @@ -317,22 +296,15 @@ __global__ void Code2x8Dequant( } } -inline int ceildiv(int a, int b) { - return (a + b - 1) / b; -} +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } const int THREAD_M = 16; -void code1x16_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -345,28 +317,16 @@ void code1x16_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } -void code2x8_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -379,30 +339,20 @@ void code2x8_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute( - Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code2x8MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } void code1x16_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -417,25 +367,21 @@ void code1x16_dequant_cuda( int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code1x16Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - codebook_stride // as int4. + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. ); } // Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -451,74 +397,50 @@ void code2x8_dequant_cuda( int shared = 16 * (2 * 256 * 8 + 32 * 9); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cudaFuncSetAttribute( - Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); Code2x8Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); } -int codebook_stride(const torch::Tensor& codebooks) -{ +int codebook_stride(const torch::Tensor& codebooks) { return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); } void code1x16_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code1x16_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride(codebook) - ); + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); } -torch::Tensor code1x16_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code1x16_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); @@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat( return output; } -void code2x8_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes -) { +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code2x8_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - 2 * codebook_stride(codebook) - ); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); } -torch::Tensor code2x8_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias -) { +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code2x8_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); if (bias.has_value()) { @@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat( } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) -{ +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; int i = 0; int last = 0; assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) - { + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { *cumulative_size = codebook_partition_sizes[i].item() + last; last = *cumulative_size; } // fill in the rest with unreachable. - for (; i < 4; ++i, ++cumulative_size) - { - *cumulative_size = last*10; + for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; } return cumulative_sizes; } -} // namespace aqlm -} // namespace vllm - +} // namespace aqlm +} // namespace vllm -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); - if (nbooks == 1 && entries == (1 << 16)) - { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - if (nbooks == 2 && entries == (1 << 8)) - { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); @@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant( assert(out_features = codebook_partition_sizes.sum().item()); auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device()) - ); + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); - if (nbooks == 1 && entries == (1 << 16)) - { - vllm::aqlm::code1x16_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) - // weights *= scales.index({"...", 0, 0}); - - return weights; + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - if (nbooks == 2 && entries == (1 << 8)) - { - vllm::aqlm::code2x8_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) - // weights *= scales.index({"...", 0, 0}); - - return weights; + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index d1d926de18d78..813ec6716cf54 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ @@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16aef..bb8e5bbb23d7f 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,14 +1,12 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - #include #include @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -386,61 +489,61 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 3ec454f78c654..e62fe731a98d3 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -117,10 +117,10 @@ struct cutlass_2x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using StrideC = Stride, Int<0>>; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); auto a_scales_ptr = a_scales.data_ptr(); auto b_scales_ptr = b_scales.data_ptr(); @@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } // namespace -void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 37b096de23e3b..12efcac7bb919 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -120,10 +120,10 @@ struct cutlass_3x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using GemmKernel = typename Gemm::GemmKernel; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; - auto c_ptr = static_cast(out.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; @@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } } // namespace -void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index a4e696d4a3322..dab73ac6c831e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -2,29 +2,29 @@ #include #include -void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { int32_t major_capability; int32_t minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, @@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, // Checks for conformality TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); + b.size(1) == c.size(1)); TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h index 87c7c9ce66100..f9c80fcdec576 100644 --- a/csrc/quantization/fp8/amd/hip_float8.h +++ b/csrc/quantization/fp8/amd/hip_float8.h @@ -1,167 +1,137 @@ #pragma once #ifdef __HIPCC__ -#include + #include #else -#include -#include -#include -#include + #include + #include + #include + #include #endif #include "hip_float8_impl.h" -struct alignas(1) hip_fp8 -{ - struct from_bits_t - { - }; - HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } - uint8_t data; - - hip_fp8() = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; - explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) - : data(v) - { - } +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} #ifdef __HIP__MI300__ - // NOTE: ON-DEVICE... always optimal bias - explicit HIP_FP8_DEVICE hip_fp8(float v) - : data(hip_fp8_impl::to_fp8_from_fp32(v)) - { - } - - explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) - : hip_fp8(static_cast(v)) - { - } - - // Host only implementation using s/w simulation - explicit HIP_FP8_HOST -#else // __HIP__MI300__ - // both Host and DEVICE for non-MI300 using s/w simulation - explicit HIP_FP8_HOST_DEVICE -#endif // __HIP__MI300__ - hip_fp8(float v) - { - data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); - } - - explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) - : hip_fp8(static_cast(v)) - { - } + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} #ifdef __HIP__MI300__ - // upcast using device specific intrinsic - explicit inline HIP_FP8_DEVICE operator float() const - { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - - return fval; - } - - explicit inline HIP_FP8_HOST operator float() const -#else // __HIP__MI300__ - explicit inline HIP_FP8_HOST_DEVICE operator float() const -#endif // __HIP__MI300__ - { - return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); - } + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } }; -namespace std -{ -inline hip_fp8 sin(hip_fp8 a) -{ - return hip_fp8(sinf(float(a))); -} -inline hip_fp8 cos(hip_fp8 a) -{ - return hip_fp8(cosf(float(a))); -} -HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) -{ - return a; -} -} // namespace std +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std // Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) -{ - return os << float(f8); +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); } // all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) -{ - return (fa + float(b)); +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); } -inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) -{ - return (float(a) + fb); +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); } -inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) -{ - return hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); } -inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) -{ - return a = hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); } // overloading multiplication, always returns float, -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) -{ - return float(a) * float(b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); } -inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) -{ - return (a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) -{ - return (float(a) * b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); } -inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); } // overloading for compare -inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) -{ - return (a.data == b.data); +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); } -inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) -{ - return (a.data != b.data); +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); } -inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) >= static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); } -inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) > static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); } diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h index e05905b4e49e8..90251c3539534 100644 --- a/csrc/quantization/fp8/amd/hip_float8_impl.h +++ b/csrc/quantization/fp8/amd/hip_float8_impl.h @@ -1,316 +1,316 @@ #pragma once -#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) -#define __HIP__MI300__ +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ #endif #ifdef __HIPCC__ -#define HIP_FP8_HOST_DEVICE __host__ __device__ -#define HIP_FP8_HOST __host__ -#define HIP_FP8_DEVICE __device__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ #else -#define HIP_FP8_HOST_DEVICE -#define HIP_FP8_HOST -#define HIP_FP8_DEVICE + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE #endif -namespace hip_fp8_impl -{ +namespace hip_fp8_impl { #ifdef __HIP__MI300__ -HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) -{ - uint8_t i8data; - union { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // NOTE: not endian independent - } val; - - uint32_t ival = 0; - val.fval = v; - - if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping - val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); - } - - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, - false); // false -> WORD0 - val.i32val = ival; - i8data = val.i8val[0]; - - return i8data; +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; } -#endif // __HIP__MI300__ +#endif // __HIP__MI300__ -HIP_FP8_HOST inline int clz(uint32_t x) -{ - return __builtin_clz(x); -} +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } #if defined(__HIPCC__) || defined(__CUDA_ARCH__) -HIP_FP8_DEVICE inline int clz(uint32_t x) -{ - return __clz(x); -} +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } #endif template -HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) -{ +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(wm + we == 7, "wm+we==7"); - static_assert(is_half || is_float, "Only half and float can be cast to f8"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - uint32_t x; + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { if (sizeof(T) == 4) { - x = reinterpret_cast(_x); + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } } else { - x = reinterpret_cast(_x); + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } } - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; - + } else { if (sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; - bias = 127; + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; - bias = 15; + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } - - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - - // Deal with inf and NaNs - if (negative_zero_nan) { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return 0x80; - } - } else { - // if(__hisinf(x) || __hisnan(x)) - if ((x & 0x7C00) == 0x7C00) { - return 0x80; - } - } - } else { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } else { - if ((x & 0x7C00) == 0x7C00) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } - } - if (x == 0) { - return 0; - } - - // First need to check if it is normal or denorm as there is a difference of - // implicit 1 Then need to adjust the exponent to align with the F8 exponent, - // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng - // to mantissa and truncate. And for RNE, no need to add rng. Then probably - // need to check whether there is carry and adjust exponent and mantissa again - - // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent - // bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; - - if (exponent == 0) { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } else { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if (act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal - range. For example fp8 nanoo mode, denormal exponent is -7, but if the - fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = f8_denormal_act_exponent - act_exponent; - } else { // both fp32/fp16 and f8 are in normal range - exponent_diff = 0; // exponent_diff=0 does not mean there is no difference - // for this case, - // act_exponent could be larger. Just that it does not need shift mantissa - } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - static_cast(1 << (mfmt - wm + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be - done before we shift right as shift right could rip off some residual part - and make something not midpoint look like midpoint. For example, the fp16 - number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after - shift right by 4 bits, it would look like midpoint. + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. */ - if (exponent_diff > 0) { - mantissa >>= exponent_diff; - } else if (exponent_diff == -1) { - mantissa <<= -exponent_diff; + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - bool implicit_one = mantissa & (1 << mfmt); - // if there is no implicit 1, it means the f8 is denormal and need to adjust - // to denorm exponent - f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; - bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that - // is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; - - // Now we deal with overflow - if (f8_exponent == 0) { - if ((1 << mfmt) & mantissa) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } - } else { - if ((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; - f8_exponent++; - } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; } + } - mantissa >>= (mfmt - wm); - - // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - if (f8_exponent > max_exp) { - if (clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } else { - return signed_inf; - } - } + mantissa >>= (mfmt - wm); - if (f8_exponent == 0 && mantissa == 0) { - return negative_zero_nan ? 0 : (sign << 7); + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; } - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; } template -inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) -{ +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported"); + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); - constexpr int weo = is_half ? 5 : 8; - constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - T fInf, fNegInf, fNaN, fNeg0; + T fInf, fNegInf, fNaN, fNeg0; #ifdef __HIPCC__ - if (is_half) { - const uint16_t ihInf = 0x7C00; - const uint16_t ihNegInf = 0xFC00; - const uint16_t ihNaN = 0x7C01; - const uint16_t ihNeg0 = 0x8000; - fInf = reinterpret_cast(ihInf); - fNegInf = reinterpret_cast(ihNegInf); - fNaN = reinterpret_cast(ihNaN); - fNeg0 = reinterpret_cast(ihNeg0); - } else + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else #endif - if (is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = reinterpret_cast(ifInf); - fNegInf = reinterpret_cast(ifNegInf); - fNaN = reinterpret_cast(ifNaN); - fNeg0 = reinterpret_cast(ifNeg0); - } - - if (x == 0) { - return 0; - } - - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if (negative_zero_nan) { - if (x == 0x80) { - return fNaN; - } - } else { - if (x == 0x80) { - return fNeg0; - } - if (exponent == ((1 << we) - 1)) { - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - } - } - typename std::conditional::type retval; - if (we == 5 && is_half && !negative_zero_nan) { - retval = x << 8; - return reinterpret_cast(retval); + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; } - - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); - - // subnormal input - if (exponent == 0) { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - wm); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); + } else { + if (x == 0x80) { + return fNeg0; } - exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if (exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; - exponent = 0; - } - - if (sizeof(T) == 2) { - retval = (sign << 15) | (exponent << 10) | mantissa; - } else { - retval = (sign << 31) | (exponent << 23) | mantissa; + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); } -} // namespace hip_fp8_impl +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index df0329f79d361..35123d7fc65d4 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -9,566 +9,567 @@ #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" -namespace vllm -{ +namespace vllm { #ifdef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; } template -__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) -{ - return x; +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; } // fp8 -> half template <> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8); - return res.x; +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0]; - tmp.h2r.y.data = f2[1]; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = vec_conversion(static_cast(a)); - tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f); +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8); +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); } // fp8x2 -> float2 template <> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0]; - res.y = f2[1]; - return res; -#else - float2 res; - res.x = vec_conversion(static_cast(a)); - res.y = vec_conversion(static_cast(a >> 8U)); - return res; -#endif +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ - hip_fp8 res{__bfloat162float(a)}; - return res.data; +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } // float2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; - float16 = __float22half2_rn(a); - return uint32; + float16 = __float22half2_rn(a); + return uint32; } // Float4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - return b; + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; } // Float4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; } // Float8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; } // float2 -> bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) -{ - __nv_bfloat162 b = __float22bfloat162_rn(a); - return b; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; } // Float4 -> bfloat162x2 template <> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) -{ - bf16_4_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - return b; +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; } // Float8 -> bfloat162x4 template <> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) -{ - bf16_8_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - b.z = __float22bfloat162_rn(a.z); - b.w = __float22bfloat162_rn(a.w); - return b; +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; } +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains -/* Scaled and vectorized conversions, for data exchange between high and low precision domains - - Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) - s.t. - Quantize(HP / scale) => FP8 - Dequant(FP8) * scale => HP + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP */ // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); - return tmp.u32; -#endif +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, const float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2& a, const float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, + const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) -{ - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + const float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) -{ - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t& a, const float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) -{ - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, const float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, const float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; -#else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); - return res; -#endif +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) -{ - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, const float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } - /* Quantize(HP / scale) => FP8 */ // TODO(Hai): vectorized to add // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, const float scale) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)/scale}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data) / scale}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) -{ - hip_fp8 res{__bfloat162float(a)/scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, const float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) -{ - hip_fp8 f8(a/scale); - return f8.data; +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, const float scale) { + hip_fp8 f8(a / scale); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ tmp = scaled_vec_conversion(a, scale); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } -} // fp8 -#endif // USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index b9c5d39277ca5..55be3305a9b8c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -10,17 +10,20 @@ namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : - __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); - return old; + return old; } #define FP8_E4M3_MAX std::numeric_limits::max() -template -__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( + const scalar_t val, const float scale) { float x = static_cast(val) / scale; float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); return static_cast(r); @@ -32,11 +35,10 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar // So to get the right answer, *scale needs to be initialized to // a value <= 0.0 and we need to wait for all thread blocks to // finish before consuming *scale. -template -__global__ void segmented_max_reduction( - float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { __shared__ float cache[1024]; int i = blockDim.x * blockIdx.x + threadIdx.x; @@ -56,7 +58,7 @@ __global__ void segmented_max_reduction( int ib = blockDim.x / 2; while (ib != 0) { if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; + cache[threadIdx.x] = cache[threadIdx.x + ib]; } __syncthreads(); ib /= 2; @@ -64,16 +66,16 @@ __global__ void segmented_max_reduction( // Finally, since cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); + atomicMaxFloat(scale, + cache[0] / std::numeric_limits::max()); } } -template -__global__ void scaled_fp8_quant_kernel( - c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { out[i] = scaled_fp8_conversion(input[i], *scale); @@ -81,12 +83,11 @@ __global__ void scaled_fp8_quant_kernel( } } -} // namespace vllm +} // namespace vllm -void static_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -95,21 +96,16 @@ void static_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } -void dynamic_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::segmented_max_reduction<<>>( - scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), input.data_ptr(), num_elems); + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } - diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index 4eeacf7a6f9d9..cde26dbda18cf 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -10,9 +10,9 @@ namespace vllm { #ifndef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 -#if 0 // Disable the following code to reduce the binary size. + #if 0 // Disable the following code to reduce the binary size. template __inline__ __device__ Tout vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { @@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion( template <> __inline__ __device__ uint8_t vec_conversion( const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 @@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion( from_float(b, a); return b; } -#endif + #endif /* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = @@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion( template __inline__ __device__ Tout scaled_vec_conversion( - const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { return x; } // fp8 -> half template <> __inline__ __device__ uint16_t scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); return float_to_half(half_to_float(tmp.x) * scale); @@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion( // fp8x2 -> half2 template <> __inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint16_t u16[2]; @@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion( // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint2 u32x2; @@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion( // fp8x8 -> half2x4 template <> __inline__ __device__ uint4 -scaled_vec_conversion(const uint2 &a, const float scale, +scaled_vec_conversion(const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint4 u64x2; @@ -348,7 +348,7 @@ scaled_vec_conversion(const uint2 &a, const float scale, template <> __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // Note there is no direct convert function from fp8 to bf16. // fp8 -> half @@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>( template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, @@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>( // fp8x4 -> bf16_4_t template <> __inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, @@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion( // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { - // fp8 -> half __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); uint16_t tmp = res.x; @@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion( // fp8x2 -> float2 template <> __inline__ __device__ float2 scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // fp8x2 -> half2 uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); @@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ Float4_ scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ res; res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); @@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion( // fp8x8 -> float8 template <> __inline__ __device__ Float8_ scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion( // half -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); @@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16 &a, const float scale, + const __nv_bfloat16& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const float &a, const float scale, + const float& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); @@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ float4 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#if 0 // Disable the following code to reduce the binary size. +__inline__ __device__ Tout convert(const Tin& x) { + #if 0 // Disable the following code to reduce the binary size. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return vec_conversion(x, __NV_E5M2); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return scaled_vec_conversion(x, scale, __NV_E5M2); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ - } else if (KV_DTYPE == "fp8_e5m2") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ - } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + } -} // namespace fp8 -#endif // not USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh index 4da0bc6e2df38..1b3fb3d39103f 100644 --- a/csrc/quantization/gptq/compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -9,54 +9,54 @@ namespace vllm { namespace gptq { // atomicAdd for half types, to support CC < 7.x -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); } // atomicAdd for half2 types -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); } // #if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif -#endif + #endif #endif } // namespace gptq diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index eda3436eb5375..2b6719fbdc1bc 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama */ #ifndef _matrix_view_cuh @@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo namespace vllm { namespace gptq { -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } - - __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const - { - half2* ptr = (half2*) item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } }; -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) - { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*) item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } }; -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } }; -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } }; -class MatrixView_q2_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x0f) * 2; - return (data[row * width / 16 + column / 16] >> shift) & 0x03; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - items[2] = (d >> 4) & 0x03; - items[3] = (d >> 6) & 0x03; - } +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } }; -class MatrixView_q3_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int z_w = column * 3 / 32; - int z_mod = column & 0x1f; - - if (z_mod == 10) { - return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); - } else if (z_mod == 21) { - return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); - } else if (z_mod < 10) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; - } else if (z_mod < 21) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; - } else { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; - } +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x1f); - uint32_t d; - if (shift <= 4) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); - } else if (shift == 8) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); - } else if (shift <= 16) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); - } else if (shift == 20) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); - } else { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); - } - items[0] = d & 0x07; - items[1] = (d >> 3) & 0x07; - items[2] = (d >> 6) & 0x07; - items[3] = (d >> 9) & 0x07; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } }; -class MatrixView_q8_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x03) * 8; - return (data[row * width / 4 + column / 4] >> shift) & 0xff; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x03) * 8; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x03) * 2; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - items[2] = (d >> 16) & 0xff; - items[3] = (d >> 24) & 0xff; - } +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } }; } // namespace gptq diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index cc56649917a8a..480c4986c3821 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa */ #include @@ -32,2044 +33,1824 @@ namespace gptq { #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) -#include -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); + #include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); } -#define hipblasHgemm __compat_hipblasHgemm + #define hipblasHgemm __compat_hipblasHgemm -// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_hgemm __compat_hipblasHgemm + // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. + #define rocblas_operation_none HIPBLAS_OP_N + #define rocblas_hgemm __compat_hipblasHgemm #endif -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); } -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) -{ - // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 - - float result = {}; - #pragma unroll - for (int i = 0; i < 4; i++) - { - half2 w01 = dq[i]; - float w0 = __low2float(w01); - float w1 = __high2float(w01); - float x0 = __half2float(*a_ptr++); - float x1 = __half2float(*a_ptr++); - result = fma(w0, x0, result); - result = fma(w1, x1, result); - } - float qs = __half2float(qs_h); - result *= qs; - half result_h = __float2half_rn(result); - return __hadd(result_h, g_result); +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); } -__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } -__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int* -); - +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, + const uint32_t*, const half*, + half*, const int, const int, + const int, const int, + const int*); template -__global__ void gemm_half_q_half_gptq_4bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - float scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - // Column result - float block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_2bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 2); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - - b_ptr += size_n; - a_ptr += 16; - } - - k += 16; +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_3bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / 32 * 3; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 32; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_8bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - if (n >= size_n) return; - - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 8); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } + // Zero output + if (n >= size_n) return; - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 8; - } - k += 32; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( - bool first_block, const int m_count, const int bit) -{ - #define SELECT_KERNEL(M_COUNT) \ - if (m_count == M_COUNT) { \ - if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ - if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ - if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ - if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ - } - #if BLOCK_M_SIZE_MAX >= 1 - SELECT_KERNEL(1); - #endif - #if BLOCK_M_SIZE_MAX >= 2 - SELECT_KERNEL(2); - #endif - #if BLOCK_M_SIZE_MAX >= 3 - SELECT_KERNEL(3); - #endif - #if BLOCK_M_SIZE_MAX >= 4 - SELECT_KERNEL(4); - #endif - #if BLOCK_M_SIZE_MAX >= 5 - SELECT_KERNEL(5); - #endif - #if BLOCK_M_SIZE_MAX >= 6 - SELECT_KERNEL(6); - #endif - #if BLOCK_M_SIZE_MAX >= 7 - SELECT_KERNEL(7); - #endif - #if BLOCK_M_SIZE_MAX >= 8 - SELECT_KERNEL(8); - #endif - return NULL; + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; } +void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = + pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} -void gemm_half_q_half_cuda_part -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* c, - int size_m, - int size_n, - int size_k, - int m_count, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - a, - b_q_weight, - b_gptq_qzeros, - b_gptq_scales, - c, - size_m, - size_n, - size_k, - groups, - b_q_perm - ); -} + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; -__global__ void reconstruct_exllama_8bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - // b offset - int qk = offset_k / (32 / 8); + // b offset + int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); - __syncthreads(); + __syncthreads(); - int k = offset_k; - int lk = 0; + int k = offset_k; + int lk = 0; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } - for (int p = 0; p < 4; p++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - k += 32; + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } } + k += 32; + } } -__global__ void reconstruct_exllama_4bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_3bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / 32* 3; + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / 32 * 3; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); - for (int p = 0; p < 1; p++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - if (b_q_perm) - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_2bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / (32 / 2); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / (32 / 2); - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - for (int p = 0; p < 2; p++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); -void reconstruct_exllama -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + __syncthreads(); - auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; - if (bit == 2) { - reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; - } else if (bit == 3) { - reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; - } else if (bit == 8) { - reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - reconstruct_exllama_kernel<<>> - ( - b_q_weight, - b_q_perm, - b_gptq_qzeros, - b_gptq_scales, - height, - width, - groups, - out - ); + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } } +void reconstruct_exllama(const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} __global__ void gemm_half_q_half_alt_4bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 8; - int vec_height = height * 4; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 8; - int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - __shared__ half2 deq2[256][8]; - int val = threadIdx.x / 8; - int off = threadIdx.x % 8; - for (; val < 256; val += BLOCK_KN_SIZE / 8) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0xF), __int2half_rn(val >> 4) - ); + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = + __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - + 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); - } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 8; - int k = 0; - int z_w = w / 8; - int z_mod = (w % 8) * 4; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[4]; - half2 zeros_tmp[4]; - for (int tmp_k = 0; tmp_k < 4; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 4; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } - __global__ void gemm_half_q_half_alt_8bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 4; - int vec_height = height * 2; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 4; - int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 4; - int k = 0; - int z_w = w / 4; - int z_mod = (w % 4) * 8; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[2]; - half2 zeros_tmp[2]; - for (int tmp_k = 0; tmp_k < 2; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); - res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); - res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 2; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } -void gemm_half_q_half_alt -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - int size_m, - int size_n, - int size_k, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); - gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - - auto kernel = gemm_half_q_half_alt_4bit_kernel; - if (bit == 8) { - kernel = gemm_half_q_half_alt_8bit_kernel; - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - (const half2*) a, - b_q_weight, - c, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - size_m, - size_k / 32 * bit, - size_n - ); +void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); } -template -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32 / bit; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - T w_zeros_(w_zeros, group, width); - - uint32_t w_read = w[blockIdx.y * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += bit) - { - int group = g_idx[row + s / bit]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } +template +__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, const int width, + const int group, + half* __restrict__ out) { + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = + __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } } -__global__ void reconstruct_gptq_3bit_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - MatrixView_q3_row w_zeros_(w_zeros, group, width); - - uint32_t w1 = w[(blockIdx.y * 3) * width + column]; - uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; - uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int i = 0; i < 32; i += 1) - { - int group = g_idx[row + i]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - int w_item; - if (i == 10) { - w_item = (w1 >> 30) | ((w2 << 2) & 0x4); - } else if (i == 21) { - w_item = (w2 >> 31) | ((w3 << 1) & 0x6); - } else if (i < 10) { - w_item = ((w1 >> (i * 3)) & 0x7); - } else if (i < 21) { - w_item = ((w2 >> (i * 3 - 32)) & 0x7); - } else { - w_item = ((w3 >> (i * 3 - 64)) & 0x7); - } - *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); - out_ptr += out_.width; +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + half* __restrict__ out) { + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } } -void reconstruct_gptq -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, 32 / bit); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - - auto kernel = reconstruct_gptq_kernel; - if (bit == 2) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 8) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 3) { - kernel = reconstruct_gptq_3bit_kernel; - gridDim.y = DIVIDE(height, 32); - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - b_q_weight, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - height, - width, - groups, - out - ); +void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, half* out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); } - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - half* temp_dq, - int size_m, - int size_n, - int size_k, - int groups, - bool use_exllama, - int bit -) -{ - bool use_reconstruct; +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, half* temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || + (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { - use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } else { - // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now. - use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } - if (use_reconstruct) { - // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); - } - const half alpha = __float2half(1.0f); - const half beta = __float2half(0.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha, temp_dq, size_n, - a, size_k, - &beta, c, size_n); + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); } - else if (use_exllama) - { - // Quantized matmul - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) - { - gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, - groups, bit); - } - if (last_chunk_size) - { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, - b_gptq_scales, b_g_idx, c + last_chunk * size_n, - last_chunk_size, size_n, size_k, last_chunk_size, - groups, bit); - } - } - else - { - gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } } -__global__ void shuffle_4bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } } -__global__ void shuffle_8bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } } -__global__ void shuffle_2bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } } -__global__ void shuffle_3bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } } -__global__ void make_sequential_4bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 3; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_2bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 4; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 16; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 4; - int w2_subrow = source_row & 0x0f; - int w2_row_shift = w2_subrow << 1; - int wnew2_row_shift = i << 1; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000300000003; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_3bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w_column >= w_width) return; - int w_new_row = blockIdx.y * 3; - int q_perm_idx = blockIdx.y << 5; - uint32_t dst[3] = {0, 0, 0}; - - #pragma unroll - for (int i = 0; i < 32; i++) - { - int source_row = q_perm[q_perm_idx++]; - int z_w = (source_row / 32) * 3; - int z_mod = source_row % 32; - int z_bit; - - if (z_mod != 10){ - if (z_mod != 21){ - z_bit = z_mod; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - - uint64_t src; - if (z_mod == 10) { - src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); - } else if (z_mod == 21){ - src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); +__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - src = w[z_w * w_width + w_column]; - src >>= z_bit; - src &= 0x07; + z_bit *= 3; } + } else { + z_w += 1; + } + } - z_w = 0; - if (i != 10){ - if (i != 21){ - z_bit = i; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - if (i == 10) { - dst[z_w] |= (src & 0x03) << 30; - dst[z_w + 1] |= ((src & 0x4) >> 2); - } else if (i == 21) { - dst[z_w] |= (src & 0x01) << 31; - dst[z_w + 1] |= ((src & 0x6) >> 1); + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | + ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | + ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - dst[z_w] |= (src << z_bit); + z_bit *= 3; } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); } - w_new[w_new_row * w_width + w_column] = dst[0]; - w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; - w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; } -__global__ void make_sequential_8bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 2; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 2; - int w2_subrow = source_row & 0x03; - int w2_row_shift = w2_subrow << 3; - int wnew2_row_shift = i << 3; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x000000ff000000ff; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } +void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, + int width, int bit) { + if (q_perm) { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); -void shuffle_exllama_weight -( - uint32_t* q_weight, - int* q_perm, - int height, - int width, - int bit -) -{ - if (q_perm) - { - uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 32 * bit; - - auto kernel = make_sequential_4bit_kernel; - if (bit == 2) { - kernel = make_sequential_2bit_kernel; - } else if (bit == 3) { - kernel = make_sequential_3bit_kernel; - gridDim.y = height / 32; - } else if (bit == 8) { - kernel = make_sequential_8bit_kernel; - } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - q_weight, - new_qweight, - q_perm, - width - ); - // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - // Cleanup - cudaDeviceSynchronize(); - cudaFree(new_qweight); - } dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - auto shuffle_kernel = shuffle_4bit_kernel; + gridDim.y = height / 32 * bit; + + auto kernel = make_sequential_4bit_kernel; if (bit == 2) { - shuffle_kernel = shuffle_2bit_kernel; + kernel = make_sequential_2bit_kernel; } else if (bit == 3) { - shuffle_kernel = shuffle_3bit_kernel; + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; } else if (bit == 8) { - shuffle_kernel = shuffle_8bit_kernel; + kernel = make_sequential_8bit_kernel; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(q_weight, height, width); + kernel<<>>(q_weight, new_qweight, q_perm, + width); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, + height / 32 * bit * width * sizeof(uint32_t), + cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm -( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); - - vllm::gptq::gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - (const uint32_t*) b_q_weight.data_ptr(), - (const uint32_t*)b_gptq_qzeros.data_ptr(), - (const half*) b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), - (half*) c.data_ptr(), - (half*) temp_dq.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - b_gptq_qzeros.size(0), // group number - use_exllama, - bit - ); - return c; +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty( + {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + (const uint32_t*)b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*)b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama, bit); + return c; } -void gptq_shuffle -( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); - vllm::gptq::shuffle_exllama_weight( - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 32 / bit, - q_weight.size(1), - bit - ); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*)q_weight.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 + ? NULL + : (int*)q_perm.data_ptr(), + q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh index 295872a91de37..ca0f810608d1b 100644 --- a/csrc/quantization/gptq/qdq_2.cuh +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -14,71 +14,60 @@ namespace gptq { // // ffddbb99 77553311 eeccaa88 66442200 -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint32_t qa0 = qa & 0x03; - uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; } -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z4 = __half2half2(z4_); - const half2 z16 = __half2half2(z16_); - const half2 z64 = __half2half2(z64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh index 3e7ecde752ba3..0d5c2adf5dbbe 100644 --- a/csrc/quantization/gptq/qdq_3.cuh +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -11,128 +11,136 @@ namespace gptq { // vjjjhhhf ffdddbbb uiiiggge eecccaaa // vtttrrrp ppnnnlll usssqqqo oommmkkk -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y8, z8); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hfma2( q3.as_half2, y8, z8); - dq[ 4] = __hfma2( q4.as_half2, y64, z64); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hfma2( q6.as_half2, y8, z8); - dq[ 7] = __hadd2( q7.as_half2, z1); - dq[ 8] = __hfma2( q8.as_half2, y8, z8); - dq[ 9] = __hfma2( q9.as_half2, y64, z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index 881f353f6564d..7f65d2d2819b1 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -13,133 +13,112 @@ namespace gptq { // // 77775555 33331111 66664444 22220000 -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z16 = __half2half2(z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1z16)[2], - half2 (&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - half2 scale2 = __half2half2(scale); + half2 scale2 = __half2half2(scale); - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1z16)[2], - half2(&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); } - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1z16)[2], - half2 (&y1y16)[2], - int stride, - bool scaled -) -{ - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) - { - dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } - else - { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) - } +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } } } // namespace gptq } // namespace vllm diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh index 0c7ad7876140b..feb5d220424b0 100644 --- a/csrc/quantization/gptq/qdq_8.cuh +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -__forceinline__ __device__ void shuffle_8bit_4 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_8bit_8 -( - const uint32_t q_0, - const uint32_t q_1, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh index 1722a9aa6cb34..9426408fec502 100644 --- a/csrc/quantization/gptq/qdq_util.cuh +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -union half2_uint32 -{ - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} }; -union half_uint16 -{ - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} }; // Max_scale premultiplied by 1/256 -__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) -{ - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; } -__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) -{ - return __hmul(__int2half_rn(q - qzero), scale); +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); } -__forceinline__ __device__ half dq_ns(const int q, const int qzero) -{ - //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); } -__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) -{ - return (int)((q >> shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); } -__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) -{ - return (int)(__funnelshift_rc(q0, q1, shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); } } // namespace gptq diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 34950a5d13cf5..c573b9041065b 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -22,53 +22,58 @@ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ - std::is_same::value || std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template inline std::string str(T x) { return std::to_string(x); } +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} namespace gptq_marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -81,24 +86,26 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template -__device__ inline void mma(const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - typename ScalarType::FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } @@ -107,8 +114,9 @@ __device__ inline void mma(const typename ScalarType::FragA &a_frag, // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template -__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: -// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 template __device__ inline typename ScalarType::FragB dequant_4bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -161,16 +172,17 @@ __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -184,7 +196,7 @@ __device__ inline typename ScalarType::FragB dequant_4bit(&lo), + frag_b[0] = __hfma2(*reinterpret_cast(&lo), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -193,10 +205,12 @@ __device__ inline typename ScalarType::FragB dequant_4bit __device__ inline typename ScalarType::FragB dequant_8bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -214,24 +228,26 @@ __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; - uint32_t * fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); fp32_intermediates[0] -= 8388736.f; @@ -240,8 +256,10 @@ __device__ inline typename ScalarType::FragB dequant_8bit(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); return frag_b; } @@ -249,30 +267,32 @@ __device__ inline typename ScalarType::FragB dequant_8bit -__device__ inline void scale(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s, int i) { +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) template -__device__ inline void scale4(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s_1, - typename ScalarType::FragS &frag_s_2, - typename ScalarType::FragS &frag_s_3, - typename ScalarType::FragS &frag_s_4, +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; + using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); @@ -280,14 +300,15 @@ __device__ inline void scale4(typename ScalarType::FragB &frag_b, // Given 2 floats multiply by 2 scales (halves) template -__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { - scalar_t *s_ptr = reinterpret_cast(&s); +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -302,7 +323,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -321,11 +342,10 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; if (finish_row > size_m) { @@ -341,9 +361,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int offset = row * row_stride; - half const *a_row_half = - reinterpret_cast(a_int4_ptr + offset); - half *out_half = reinterpret_cast(out_int4_ptr + offset); + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; @@ -374,31 +393,32 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -445,11 +465,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -605,7 +620,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; @@ -623,13 +638,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -639,30 +654,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_g_idx = sh_b + (stages * b_sh_stage); - int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; @@ -706,18 +721,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } @@ -730,10 +745,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const *cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], @@ -742,7 +757,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } else { if constexpr (group_blocks != -1) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group @@ -782,15 +797,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -805,8 +821,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; @@ -822,10 +838,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; @@ -838,9 +854,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } @@ -867,7 +883,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // thread-id) int warp_id = threadIdx.x / 32; int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; @@ -875,7 +891,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cur_k += warp_row * 16; int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + @@ -883,45 +899,44 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (is_same_group[pipe]) { if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread + 9}; // Tensor core offsets per thread -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -933,7 +948,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); @@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { @@ -962,7 +979,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -987,38 +1004,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -1049,39 +1066,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1115,8 +1132,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { - scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1124,13 +1142,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((scalar_t2 *)sh)[idx] = res; + ((scalar_t2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -1147,7 +1165,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -1162,7 +1180,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; @@ -1193,9 +1211,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // have even length meaning that the next iteration will always start at // index 0. -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); @@ -1261,8 +1279,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { @@ -1270,8 +1288,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } @@ -1282,31 +1300,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); } } } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -1315,13 +1337,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading @@ -1341,23 +1362,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ - } + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } typedef struct { int thread_k; @@ -1389,7 +1411,7 @@ thread_config_t large_batch_thread_configs[] = { }; -int get_scales_cache_size(thread_config_t const &th_config, int prob_m, +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; @@ -1402,15 +1424,15 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { @@ -1420,7 +1442,7 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; @@ -1451,12 +1473,12 @@ bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, float pipe_size = (a_size + b_size) * pipe_stages; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } -bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { @@ -1519,43 +1541,43 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, - void *g_idx, void *perm, void *a_tmp, int prob_m, - int prob_n, int prob_k, void *workspace, int num_bits, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) { @@ -1639,15 +1661,15 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; - const int *g_idx_ptr = (const int *)g_idx; - const int *perm_ptr = (const int *)perm; - int4 *a_tmp_ptr = (int4 *)a_tmp; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; - int *locks = (int *)workspace; + int* locks = (int*)workspace; if (has_act_order) { // Permute A columns @@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; @@ -1709,11 +1730,11 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { // Verify num_bits @@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, gptq_marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, + is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 35ea48aaba310..ba5368ea8835f 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -11,22 +11,23 @@ namespace gptq_marlin { -// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per -// schedule allows some more latency hiding. At the same time, we want relatively few warps to have -// many registers per warp and small tiles. +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; -static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; -static constexpr int max_par = 16; +static constexpr int max_par = 16; template struct Vec { - T elems[n]; + T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; @@ -35,30 +36,35 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - // No support for async +// No support for async #else -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} template __device__ inline void cp_async_wait() { @@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh index 7881abbe4cbbf..ca1b7099d6ec7 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -5,58 +5,73 @@ #include #include - namespace gptq_marlin { template -class ScalarType { -}; +class ScalarType {}; template <> class ScalarType { -public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - - static __device__ float inline num2float(const half x) { return __half2float(x); } - - static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } - - static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } }; template <> class ScalarType { -public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } #endif }; -} +} // namespace gptq_marlin #endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 0d3da6240dbca..4adc158eb14ea 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -12,14 +12,14 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) { +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -61,8 +61,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int perm_size = tile_k_size / 4; - int4 *sh_perm_ptr = sh; - int4 *sh_pipe_ptr = sh_perm_ptr; + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { sh_pipe_ptr += perm_size; } @@ -76,7 +76,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; @@ -92,22 +92,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_n = n_tile_id * tile_n_size; - int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; - uint32_t const *sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( + reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } @@ -120,7 +120,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( + reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } @@ -151,10 +151,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; @@ -176,17 +176,16 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } else { - uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; -#pragma unroll + #pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; @@ -206,7 +205,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -218,7 +217,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t res1 = 0; uint32_t res2 = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -230,14 +229,14 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; -#pragma unroll + #pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; @@ -248,7 +247,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -260,21 +259,21 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } -} // namespace gptq_marlin - -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 @@ -318,11 +317,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, bool has_perm = perm.size(0) != 0; // Get ptrs - uint32_t const *b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const *perm_ptr = - reinterpret_cast(perm.data_ptr()); - uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); // Get dev info int dev = b_q_weight.get_device(); diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 002a70001885d..03d66cecedf1f 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -25,7 +25,10 @@ #include -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin { @@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; @@ -51,29 +55,32 @@ using I4 = Vec; using FragA = Vec; using FragB = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -82,28 +89,30 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -138,24 +148,24 @@ __device__ inline FragB dequant(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -170,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -187,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { } } -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -241,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -293,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; init_slice(); - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time // constant constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory + 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = - a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads + (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile + a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -368,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -387,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -403,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; @@ -421,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -475,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can // avoid doing so for each weight. - if (group_blocks != -1) - scale(frag_b0, frag_s[k % 2][j], 0); + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) - scale(frag_b1, frag_s[k % 2][j], 1); -#pragma unroll + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -530,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -592,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + __half2float(reinterpret_cast<__half*>(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -658,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { + auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); if (group_blocks == - -1) // for per-column quantization we finally apply the scale here + -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); - ((half2 *)sh)[idx] = res; + ((half2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -685,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -699,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -711,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { @@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk matmul(k); } slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if (group_blocks == -1 && last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); @@ -751,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -770,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -787,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -819,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; @@ -831,7 +832,7 @@ static constexpr int tile_size = 16; static constexpr int max_par = 16; static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit + 8; // We have 8 4-bit vals inside a 32 bit #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ GROUP_BLOCKS, NUM_THREADS) \ @@ -858,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || @@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, } thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { @@ -926,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) -void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, - int prob_n, int prob_k, void *workspace, int groupsize = -1, +void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 16) { int tot_m = prob_m; @@ -996,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, " is not divisible by group_blocks = ", group_blocks); } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; @@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = 64 * par; i += 4 * (par - 1); thread_m_blocks = 4; @@ -1041,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, } } -} // namespace marlin +} // namespace marlin -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M TORCH_CHECK(size_m == a.size(0), "Shape mismatch: a.size(0) = " + str(a.size(0)) + @@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h index 929b39d7642f1..16018d331bec2 100644 --- a/csrc/quantization/marlin/sparse/common/base.h +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; -template struct ShapeBase { +template +struct ShapeBase { static constexpr int M = M_, N = N_, K = K_; }; @@ -44,6 +46,6 @@ using FragA = Vec; using FragB = Vec; using FragM = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h index a49d15ca544eb..83e3578d2f511 100644 --- a/csrc/quantization/marlin/sparse/common/mem.h +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -21,41 +21,44 @@ namespace marlin_24 { // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void *smem_ptr, - const void *glob_ptr, +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, bool pred = true, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); } -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); } -__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_m); +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) @@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { : "l"(lock), "r"(val)); } } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 9319456677d36..45ab67a78a1de 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -22,51 +22,56 @@ namespace marlin_24 { // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, - const FragA &frag_b, FragC &frag_c, FragM &frag_m, +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, const int psel) { - const uint32_t *a0 = reinterpret_cast(&a_frag0); - const uint32_t *a1 = reinterpret_cast(&a_frag1); - const uint32_t *b = reinterpret_cast(&frag_b); - const uint32_t *e = reinterpret_cast(&frag_m); - float *c = reinterpret_cast(&frag_c); + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + float* c = reinterpret_cast(&frag_c); if (psel == 0) { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } else { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } } // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) { const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, - FragS &s0, float *c4, float *c5, float *c6, - float *c7, FragS &s1) { +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); @@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 42b0566183a8d..54ad27676e207 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -32,12 +32,15 @@ #else -#include "common/mem.h" -#include "common/mma.h" + #include "common/mem.h" + #include "common/mma.h" #endif -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin_24 { @@ -45,7 +48,7 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int THREADS = 256; -static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory +static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 128; @@ -54,35 +57,36 @@ static constexpr int max_par = 16; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -174,27 +179,22 @@ __global__ void Marlin_24( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -207,7 +207,7 @@ __global__ void Marlin_24( init_slice(); // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 32 * thread_k_blocks / 8; @@ -239,9 +239,9 @@ __global__ void Marlin_24( constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); constexpr int m_sh_wr_delta = threads / 2; @@ -305,7 +305,7 @@ __global__ void Marlin_24( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; } @@ -325,13 +325,13 @@ __global__ void Marlin_24( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) { a_sh_rd_trans[0][i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -344,23 +344,23 @@ __global__ void Marlin_24( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4 *meta_ptr[m_sh_iters]; -#pragma unroll + const int4* meta_ptr[m_sh_iters]; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); - int4 *sh_m = sh_s + (stages * s_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks][2]; I4 frag_b_quant[2][b_thread_vecs]; @@ -370,46 +370,43 @@ __global__ void Marlin_24( // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } - int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; -#pragma unroll + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) { if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], - meta_ptr[i]); + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); meta_ptr[i] += m_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -436,13 +433,13 @@ __global__ void Marlin_24( // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { ldsm4(frag_a[k % 2][i][0], &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); @@ -450,24 +447,24 @@ __global__ void Marlin_24( &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } // Load meta with ldsm4 - int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + int4* sh_m_stage = sh_m + m_sh_stage * pipe; ldsm4_m(frag_m[k % 2][0], &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -480,7 +477,7 @@ __global__ void Marlin_24( frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -497,7 +494,7 @@ __global__ void Marlin_24( scale(frag_b1, frag_s[k % 2][j], 1); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], frag_m[k % 2][j / 2], j % 2); @@ -518,41 +515,41 @@ __global__ void Marlin_24( int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); -// Parallel logarithmic shared memory reduction. We make sure to avoid any -// unnecessary read or write iterations, e.g., for two warps we write only once -// by warp 1 and read only once by warp 0. -#pragma unroll + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -562,9 +559,9 @@ __global__ void Marlin_24( }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -574,7 +571,7 @@ __global__ void Marlin_24( int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; @@ -584,10 +581,10 @@ __global__ void Marlin_24( int col = 2 * ((threadIdx.x % 32) % 4); if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + @@ -599,32 +596,32 @@ __global__ void Marlin_24( cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + col + (i % 2) < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2] += __half2float( - reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); } } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2]); } @@ -643,9 +640,9 @@ __global__ void Marlin_24( auto write_result = [&]() { int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -654,22 +651,22 @@ __global__ void Marlin_24( c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + (threadIdx.x % (2 * 2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, - float c4, float c5, float c6, float c7, FragS &s1) { + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { uint2 res[2]; res[0] = to_half4(c0, c1, c2, c3); res[1] = to_half4(c4, c5, c6, c7); - half2 *tmp = (half2 *)&res; + half2* tmp = (half2*)&res; // for per-column quantization we finally apply the scale here if constexpr (group_blocks == -1 && num_bits == 4) { tmp[0] = __hmul2(tmp[0], s0[0]); @@ -677,12 +674,12 @@ __global__ void Marlin_24( tmp[2] = __hmul2(tmp[2], s1[0]); tmp[3] = __hmul2(tmp[3], s1[1]); } - ((int4 *)sh)[idx] = *((int4 *)&res[0]); + ((int4*)sh)[idx] = *((int4*)&res[0]); }; // RLC: only warp 0 and 1 baseline example if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { int wr = c_sh_wr; write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], @@ -707,7 +704,7 @@ __global__ void Marlin_24( } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -721,9 +718,8 @@ __global__ void Marlin_24( // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -733,10 +729,10 @@ __global__ void Marlin_24( // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -747,8 +743,7 @@ __global__ void Marlin_24( pipe++; slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -762,13 +757,11 @@ __global__ void Marlin_24( // write-out if constexpr (group_blocks == -1) { if constexpr (num_bits == 8) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } else { if (last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } } @@ -780,14 +773,14 @@ __global__ void Marlin_24( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } } @@ -798,7 +791,7 @@ __global__ void Marlin_24( // overflow in fp16) if constexpr (group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], @@ -827,13 +820,13 @@ __global__ void Marlin_24( } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; @@ -843,19 +836,17 @@ __global__ void Marlin_24( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -#pragma unroll + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] -= m_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -866,26 +857,26 @@ __global__ void Marlin_24( #endif -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ } -void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, - void *s, int prob_m, int prob_n, int prob_k, - void *workspace, int num_bits, int groupsize = -1, +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_m = -1, int sms = -1, int max_par = 16) { int tot_n = prob_n; @@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, if (thread_k == -1 || thread_m == -1) { if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important than - // better compute utilization + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization thread_k = 128; thread_m = 128; } else { @@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_m_blocks = thread_m / 16; int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; int blocks = sms; @@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - const int4 *meta_ptr = (const int4 *)meta; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_n_blocks; i += 4) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; @@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_n_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_n = 64 * par; i += 4 * (par - 1); thread_n_blocks = 4; @@ -956,16 +946,16 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // For compilation speed, we only define the kernel configurations that have // seemed useful (in terms of performance) in our testing, however many more // are, in principle, possible. - + // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group + if (false) { + } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 16, 2, 2, 4) CALL_IF_2_4(4, 16, 3, 2, -1) CALL_IF_2_4(4, 16, 3, 2, 4) @@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, CALL_IF_2_4(4, 16, 4, 2, 4) // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(8, 16, 2, 2, 4) CALL_IF_2_4(8, 16, 3, 2, -1) CALL_IF_2_4(8, 16, 3, 2, 4) @@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } -} // namespace marlin_24 +} // namespace marlin_24 -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits @@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is not divisible by tile_size = " + str(marlin_24::tile_size)); int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify meta TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, @@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, ", is not divisible by b_scales.size(0) = " + str(b_scales.size(0))); groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 + groupsize /= 2; // Because of 24 } // Verify groupsize diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 09964903622b4..1b339fa4b392b 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM - const half2* __restrict__ vec, + const half2* __restrict__ vec, #else - const __half2* __restrict__ vec, + const __half2* __restrict__ vec, #endif - const int* __restrict__ mat, + const int* __restrict__ mat, #ifndef USE_ROCM - half2* __restrict__ mul, + half2* __restrict__ mul, #else - float2* __restrict__ mul, + float2* __restrict__ mul, #endif - const __half* __restrict__ lookup_table, - int height, - int width, - int batch, - int vec_height -) { + const __half* __restrict__ lookup_table, int height, int width, int batch, + int vec_height) { const int blockwidth2 = BLOCKWIDTH / 2; int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; #ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; @@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel( unsigned int tmp1; unsigned int lut_index1, lut_index2; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b) { i = width * row + col; res = __int2half_rd(0); k = 0; __syncthreads(); if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; + blockvec[threadIdx.x] = + vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + + threadIdx.x]; __syncthreads(); while (k < blockwidth2) { @@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); #else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), + res); #endif i += width; @@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel( } } -} // namespace squeezellm -} // namespace vllm +} // namespace squeezellm +} // namespace vllm // 4-bit matvec kernel (LUT-based) -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table -) { +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table) { int height = mat.size(0); int width = mat.size(1); int batch = vec.size(0); int vec_height = vec.size(1); - dim3 blocks( - (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH - ); + dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH); dim3 threads(BLOCKWIDTH); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::squeezellm::NUQ4MatMulKernel<<>>( #ifndef USE_ROCM - (half2*) vec.data(), + (half2*)vec.data(), #else - (__half2*) vec.data_ptr(), + (__half2*)vec.data_ptr(), #endif - mat.data_ptr(), + mat.data_ptr(), #ifndef USE_ROCM - (half2*) mul.data(), - (__half*) lookup_table.data(), + (half2*)mul.data(), (__half*)lookup_table.data(), #else - (float2*) mul.data_ptr(), - (__half*) lookup_table.data_ptr(), + (float2*)mul.data_ptr(), + (__half*)lookup_table.data_ptr(), #endif - height, width, batch, vec_height - ); + height, width, batch, vec_height); } #undef BLOCKWIDTH diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bb5171f854d55..9af4aae516151 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -20,12 +21,12 @@ #include "cuda_compat.h" namespace vllm { -template +template __inline__ __device__ T warpReduceSum(T val) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); - #pragma unroll +#pragma unroll for (int mask = numLanes >> 1; mask > 0; mask >>= 1) val += VLLM_SHFL_XOR_SYNC(val, mask); return val; @@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) { } /* Calculate the sum of all elements in a block */ -template +template __inline__ __device__ T blockReduceSum(T val) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduceSum(val); - // Calculates max number of lanes that need to participate in the last warpReduce + // Calculates max number of lanes that need to participate in the last + // warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; int wid = threadIdx.x / WARP_SIZE; - if (lane == 0) - shared[wid] = val; + if (lane == 0) shared[wid] = val; __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] + : (T)(0.0f); val = warpReduceSum(val); } else { // A single warpReduce is equal to blockReduce @@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } -} // namespace vllm +} // namespace vllm diff --git a/format.sh b/format.sh index 5f6e20256d404..aaec25a8aa0dc 100755 --- a/format.sh +++ b/format.sh @@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) +CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' @@ -179,7 +181,6 @@ lint_changed() { } # Run Ruff -echo 'vLLM ruff:' ### This flag lints individual files. --files *must* be the first command line ### arg to use this option. if [[ "$1" == '--files' ]]; then @@ -192,6 +193,7 @@ else # Format only the files that changed in last commit. lint_changed fi +echo 'vLLM ruff: Done' # check spelling of specified files isort_check() { @@ -233,6 +235,59 @@ else fi echo 'vLLM isort: Done' +# Clang-format section +# Exclude some files for formatting because they are vendored +# NOTE: Keep up to date with .github/workflows/clang-format.yml +CLANG_FORMAT_EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' +) + +# Format specified files with clang-format +clang_format() { + clang-format -i "$@" +} + +# Format files that differ from main branch with clang-format. +clang_format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause clang-format to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + # Get the list of changed files, excluding the specified ones + changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}")) + if [ -n "$changed_files" ]; then + echo "$changed_files" | xargs -P 5 clang-format -i + fi +} + +# Format all files with clang-format +clang_format_all() { + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \ + | xargs clang-format -i +} + +# Run clang-format +if [[ "$1" == '--files' ]]; then + clang_format "${@:2}" +elif [[ "$1" == '--all' ]]; then + clang_format_all +else + clang_format_changed +fi +echo 'vLLM clang-format: Done' + + if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' echo 'Changes not staged for commit:' diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f6c27d95fe6a..cf2bb9bef22d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ tomli==2.0.1 ruff==0.1.5 codespell==2.2.6 isort==5.13.2 +clang-format==18.1.5 # type checking mypy==1.9.0 From c74c913bfbefc5d7a1302557eb35cdcbecd91f67 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 22 May 2024 22:02:58 +0900 Subject: [PATCH 096/124] [misc] remove comments that were supposed to be removed (#4977) --- tests/lora/conftest.py | 1 - vllm/lora/models.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 95fc65cdd1a8f..e5cf9cd48b65d 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -185,7 +185,6 @@ def long_context_lora_files_32k(): return snapshot_download(repo_id="SangBinCho/long_context_32k_testing") -# SANG-TODO Download long lora files. @pytest.fixture(scope="session") def long_context_infos(long_context_lora_files_16k_1, long_context_lora_files_16k_2, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index d001d17144d98..a2092d31ea9aa 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -105,8 +105,6 @@ def convert_mapping( lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset - # SANG-TODO - # index_mapping_indices[i] = i indices_list: List[Union[List[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices From 8674f9880e2d8574c2adc759027e0f27dc9b95de Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 22 May 2024 10:10:43 -0400 Subject: [PATCH 097/124] [Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954) Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs --- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 6 ++- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 5 ++- tests/kernels/test_cutlass.py | 41 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index e62fe731a98d3..3a6b8a226e18c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,6 +1,8 @@ #include #include +#include + // clang-format will break include orders // clang-format off #include "cute/tensor.hpp" @@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, size_t workspace_size = gemm_op.get_workspace_size(args); cutlass::device_memory::allocation workspace(workspace_size); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get()); + cutlass::Status status = gemm_op(args, workspace.get(), stream); CUTLASS_CHECK(status); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 12efcac7bb919..5fd6d8ff20867 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, size_t workspace_size = gemm_op.get_workspace_size(args); TORCH_CHECK(workspace_size == 0); - cutlass::Status status = gemm_op.run(args); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + cutlass::Status status = gemm_op.run(args, stream); CUTLASS_CHECK(status); } } // namespace diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index fdfd1dee29ce6..2cf0e86e5ca44 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -190,3 +190,44 @@ def test_cutlass_subset(): b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +# Test to make sure cuda graphs work +class CutlassLayer(torch.nn.Module): + + def __init__(self, b, scale_a, scale_b, out_dtype): + super().__init__() + self.b = b + self.scale_a = scale_a + self.scale_b = scale_b + self.out_dtype = out_dtype + + def forward(self, a): + return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b, + self.out_dtype) + + +def test_cutlass_cuda_graph(): + m, n, k = 512, 512, 512 + + a = to_int8(torch.randn((m, k), device="cuda")) + b = to_int8(torch.randn((n, k), device="cuda").t()) + + scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10) + scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10) + + # Construct a trivial model with a single layer that calls a CUTLASS kernel + model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = model(a) + out.zero_() + g.replay() + + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) From a3a73ab0696b6692f3eecf80271a01fa97bd001d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 22 May 2024 13:28:20 -0700 Subject: [PATCH 098/124] [Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893) The 2nd PR for #4532. This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter). --- benchmarks/benchmark_latency.py | 14 ++-- benchmarks/benchmark_throughput.py | 12 ++- .../kernels/benchmark_paged_attention.py | 10 +-- tests/models/test_fp8.py | 80 ++++++++++++------- vllm/attention/layer.py | 27 ++++++- vllm/config.py | 8 +- vllm/engine/arg_utils.py | 7 +- .../model_executor/layers/quantization/fp8.py | 47 ++++++++++- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 6 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/chatglm.py | 13 ++- vllm/model_executor/models/commandr.py | 13 ++- vllm/model_executor/models/dbrx.py | 13 ++- vllm/model_executor/models/deepseek.py | 3 +- vllm/model_executor/models/falcon.py | 9 ++- vllm/model_executor/models/gemma.py | 3 +- vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/jais.py | 13 ++- vllm/model_executor/models/llama.py | 32 ++++---- vllm/model_executor/models/minicpm.py | 3 +- vllm/model_executor/models/mixtral.py | 29 +++++-- vllm/model_executor/models/mixtral_quant.py | 15 ++-- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/olmo.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/orion.py | 3 +- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 3 +- vllm/model_executor/models/stablelm.py | 3 +- vllm/model_executor/models/starcoder2.py | 15 ++-- vllm/model_executor/models/xverse.py | 3 +- vllm/utils.py | 2 + vllm/worker/model_runner.py | 17 ++-- 40 files changed, 284 insertions(+), 158 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index f84e3453947c9..a9657f7859750 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -153,15 +153,13 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 41f443968c3c4..7c8cb5ee8cea2 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -323,15 +323,13 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=["auto", "fp8"], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index ca7967c1ab0d2..fc9621e885dc4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -183,13 +183,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") args = parser.parse_args() print(args) diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..0a5819ea3f054 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -16,31 +16,55 @@ MAX_MODEL_LEN = 1024 MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", "meta-llama/Meta-Llama-3-8B-Instruct", ] EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', - ], - "meta-llama/Meta-Llama-3-8B-Instruct": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' - ], + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system made up of several basic components that work together to enable it to', + 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' + ] + }, + "meta-llama/Meta-Llama-3-8B-Instruct": { + "auto": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' + ] + }, } capability = torch.cuda.get_device_capability() @@ -52,14 +76,14 @@ @pytest.mark.skipif(fp8_not_supported, reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) -def test_models( - example_prompts, - model_name, -) -> None: +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +def test_models(example_prompts, model_name, kv_cache_dtype) -> None: model = LLM(model=model_name, max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, enforce_eager=True, - quantization="fp8") + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ @@ -81,8 +105,8 @@ def test_models( generations.append(outputs[0].outputs[0].text) del model - print(generations) - expected_strs = EXPECTED_STRS_MAP[model_name] + print(model_name, kv_cache_dtype, generations) + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] for i in range(len(example_prompts)): generated_str = generations[i] expected_str = expected_strs[i] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 4299726bdca4b..dc7b3940bc9b7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,8 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) class Attention(nn.Module): @@ -30,6 +32,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() if cache_config is not None: @@ -40,6 +43,27 @@ def __init__( block_size = 16 if num_kv_heads is None: num_kv_heads = num_heads + + # The default kv_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized kv_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._kv_scale = 1.0 + quant_method = quant_config.get_quant_method( + self) if quant_config else None + if quant_method is not None: + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # When FP8 quantization is enabled, we make a parameter + # "kv_scale" so that it can be loaded from FP8 checkpoint. + # The kv_scale will then be converted back + # to self._kv_scale in a native float32 value after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() @@ -57,10 +81,9 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, - kv_scale) + self._kv_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/config.py b/vllm/config.py index 3256c11967914..b245a1a3ee6d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -355,14 +355,12 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8": + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " - "But it may cause slight accuracy drop without scaling " - "factors. FP8_E5M2 (without scaling) is only supported on " - "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " - "is instead supported for common inference criteria.") + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0a9ec7472fbca..538e3427e37fb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -191,12 +191,11 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' - 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=nullable_str, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ff996741c1d00..b084b9cee4983 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -8,8 +8,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -58,9 +59,13 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): return Fp8LinearMethod(self) + if isinstance(layer, Attention): + return Fp8KVCacheMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -251,6 +256,44 @@ def apply(self, return torch.narrow(output, 0, 0, x.shape[0]) +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint.") + del layer.kv_scale + + def all_close_1d(x: torch.Tensor) -> bool: assert len(x.shape) == 1 return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index cb99939cbb17a..313762b1353d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -268,7 +268,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 58b3405d319d1..babb92e7cdcef 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -154,7 +154,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.rotary_emb = get_rope( self.head_dim, @@ -166,7 +167,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index fe2de87b20dc9..a29aee4cffb7d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -111,7 +111,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ed65d76f7b5b9..e3a5e43e23e1c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -86,13 +86,12 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7354d11f98b15..84786921ce1b4 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -177,13 +177,12 @@ def __init__( rope_scaling=self.rope_scaling, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 083ddf0159f71..8ff19a2015e0f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -218,13 +218,12 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 62e04f9649915..8fbda2638aaa3 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -232,7 +232,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ab9e1994be426..ba707adb03dfe 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -153,7 +153,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + quant_config=quant_config) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -165,13 +166,15 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index d1502b718a773..27dda00b66af4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -157,7 +157,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 0deaa58ed9eb5..cc83f6eb6d94d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -75,7 +75,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index c20fb3230c394..f488ef40039c0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -88,7 +88,8 @@ def __init__( self.head_dim, scale=self.scale, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5f4d8ec3d3a7a..47fd5788a4c35 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -88,7 +88,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index dcb52ff666c95..eb0fcc8f26a58 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -89,7 +89,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 65f7ddb8b082c..e75c567f589c8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -117,7 +117,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index df30fd1ba0a37..869b8fc91fd64 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -105,13 +105,12 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f2996c240aaf4..23141124e69e1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,7 +47,7 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, print_warning_once class LlamaMLP(nn.Module): @@ -119,15 +119,6 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # This will be overwritten by model initialization if we are using it. - # N.B. currently we only support per tensor scalar scaling factors - # & only applicable to ROCm (AMD GPU). - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - self.kv_scale = 1.0 - self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -155,7 +146,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -167,8 +159,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, - self.kv_scale) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -421,6 +412,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -445,7 +449,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: # scaling_factor = tensor_amax / FPtype_max scaling_factor *= 2 if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.kv_scale = scaling_factor + layer_self_attn.attn._kv_scale = scaling_factor else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 0b85cf1c94795..59fbf8e1b35f2 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -236,7 +236,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e3ac33e0452fe..ea95cf7380d54 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -308,14 +308,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -581,6 +580,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ee2626b1c1aa2..9b99ff729aadd 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -213,14 +213,13 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 716ac51cde94d..5f9e4d86f3cd8 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -110,7 +110,8 @@ def __init__( scaling, alibi_slopes=alibi_slopes, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 69f23bbfb5d0a..39270f71ec46f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -96,7 +96,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) # Attention output projection. self.o_proj = RowParallelLinear( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d241756e50f4a..4bf59105dbabb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -91,7 +91,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 59cd42e31b374..133a10e6bb3e8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -121,7 +121,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 193a29d20c894..c8e61735a9bb6 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -110,7 +110,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d158846a3a1f5..d22ea6b79de0f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -106,7 +106,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 97ab6168c3230..ec203c3b9001a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -141,7 +141,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index a0d3b0406ef4a..564536f2dd248 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -241,7 +241,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 8b4a5507feade..a6ed3800bed0f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -127,7 +127,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_key_value_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 3c19d63276a77..91ffd0861c39d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -97,14 +97,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 6ef230a8ebbca..dda13d83f89a3 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -135,7 +135,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/utils.py b/vllm/utils.py index 552b43e7f82b2..4cb9d905097bf 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,6 +31,8 @@ "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, } diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e264fede0ee64..9720363ac300e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,5 @@ import time +import warnings from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np @@ -168,11 +169,21 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently scaled KV cache is only enabled on ROCm + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. if self.model_config.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) self.model.load_kv_cache_scales( self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " @@ -183,10 +194,6 @@ def load_model(self) -> None: "Using FP8 KV cache but no scaling factors " "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - elif self.model_config.quantization_param_path is not None: - logger.warning("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") def save_sharded_state( self, From 97b030005c7f5cde7c1b97c718a8841db7d6220b Mon Sep 17 00:00:00 2001 From: raywanb <112235519+raywanb@users.noreply.github.com> Date: Thu, 23 May 2024 04:58:59 +0800 Subject: [PATCH 099/124] [Model] LoRA gptbigcode implementation (#3949) --- csrc/punica/bgmv/bgmv_config.h | 4 +++ tests/lora/test_punica.py | 2 ++ vllm/lora/models.py | 2 ++ vllm/model_executor/models/gpt_bigcode.py | 31 +++++++++++++++++++---- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 98ac8de779e13..4b376261d30d2 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 4096) \ @@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 6144) \ + f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ @@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ @@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6400, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 193e3906997c4..f021c003b1322 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -58,6 +58,7 @@ def _lora_ref_impl( 2560, 2752, 3072, + 3328, 3456, 3584, 4096, @@ -66,6 +67,7 @@ def _lora_ref_impl( 5504, 5632, 6144, + 6400, 6848, 6912, 7168, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index a2092d31ea9aa..3e82856866d85 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -310,7 +310,9 @@ def from_local_checkpoint( if part_name not in expected_lora_modules: unexpected_modules.append(module) # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + print(unexpected_modules, "modules") raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index f488ef40039c0..69b75763e9a3d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -191,14 +191,19 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ GPTBigCodeBlock(config, cache_config, quant_config) @@ -226,19 +231,35 @@ def forward( class GPTBigCodeForCausalLM(nn.Module): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] + + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + embedding_padding_modules = [] def __init__( self, config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, cache_config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config, + lora_config) self.lm_head_weight = self.transformer.wte.weight - self.logits_processor = LogitsProcessor(config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def forward( From eb6d3c264d0cd8e44dec16bca7947fbe96415ce9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 22 May 2024 14:17:27 -0700 Subject: [PATCH 100/124] [Core] Eliminate parallel worker per-step task scheduling overhead (#4894) --- vllm/engine/async_llm_engine.py | 10 +- vllm/engine/llm_engine.py | 8 ++ vllm/executor/distributed_gpu_executor.py | 123 ++++++++++++++++----- vllm/executor/executor_base.py | 8 ++ vllm/executor/multiproc_gpu_executor.py | 73 ++++++++----- vllm/executor/ray_gpu_executor.py | 86 +++++++-------- vllm/spec_decode/ngram_worker.py | 4 +- vllm/spec_decode/spec_decode_worker.py | 125 +++++++++++----------- vllm/worker/embedding_model_runner.py | 5 +- vllm/worker/model_runner.py | 5 +- vllm/worker/worker.py | 103 +++++++++++------- vllm/worker/worker_base.py | 7 +- 12 files changed, 348 insertions(+), 209 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8a37bac02823a..5a15ed67e3327 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -234,6 +234,14 @@ async def step_async( # Log stats. self.do_log_stats(scheduler_outputs, output) + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + await self.model_executor.stop_remote_worker_execution_loop_async() + return request_outputs async def encode_request_async( @@ -687,7 +695,7 @@ async def encode( multi_modal_data: Multi modal data per request. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. Details: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 60e23d4df15bb..0631c0de76822 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -692,6 +692,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Log stats. self.do_log_stats(scheduler_outputs, output) + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + self.model_executor.stop_remote_worker_execution_loop() + return request_outputs def do_log_stats( diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index c5b1e61112afb..f7c608af1ad39 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -1,11 +1,12 @@ +import asyncio from abc import abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput logger = init_logger(__name__) @@ -13,6 +14,16 @@ class DistributedGPUExecutor(GPUExecutor): """Abstract superclass of multi-GPU executor implementations.""" + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. @@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: - all_outputs = self._run_workers("execute_model", - driver_args=args, - driver_kwargs=kwargs) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return all_outputs[0] + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." @@ -88,39 +114,84 @@ def save_sharded_state( pattern=pattern, max_size=max_size) + @abstractmethod + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + @abstractmethod def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" raise NotImplementedError class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + @abstractmethod - async def _run_workers_async( + async def _driver_execute_model_async( self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - raise NotImplementedError + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. - async def execute_model_async(self, *args, - **kwargs) -> List[SamplerOutput]: - all_outputs = await self._run_workers_async("execute_model", - driver_args=args, - driver_kwargs=kwargs) + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError - # Only the driver worker returns the sampling results. - return all_outputs[0] + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 08aa58999b1ec..4d01939c2e38b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -74,6 +74,10 @@ def execute_model( """Executes at least one model step on the given sequences.""" raise NotImplementedError + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: raise NotImplementedError @@ -109,6 +113,10 @@ async def execute_model_async( """Executes one model step on the given sequences.""" raise NotImplementedError + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 2a7b99c9dcbe1..8fa54454907b5 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,13 +1,14 @@ import asyncio import os from functools import partial -from typing import Any, Dict, Optional, Tuple +from typing import Any, List, Optional from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -71,16 +72,34 @@ def shutdown(self): None)) is not None: worker_monitor.close() + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model( + execute_model_req=execute_model_req) + def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ if max_concurrent_workers: raise NotImplementedError( @@ -92,15 +111,12 @@ def _run_workers( for worker in self.workers ] - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs + if async_run_remote_workers_only: + # Just return futures + return worker_outputs - # Start the driver worker after all the ray workers. driver_worker_method = getattr(self.driver_worker, method) - driver_worker_output = driver_worker_method(*driver_args, - **driver_kwargs) + driver_worker_output = driver_worker_method(*args, **kwargs) # Get the results of the workers. return [driver_worker_output @@ -111,30 +127,29 @@ def check_health(self) -> None: if not self.worker_monitor.is_alive(): raise RuntimeError("Worker processes are not running") + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, DistributedGPUExecutorAsync): - async def _run_workers_async( - self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) - driver_executor = make_async(getattr(self.driver_worker, method)) + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_model(execute_model_req) - # Run all the workers asynchronously. - coros = [driver_executor(*driver_args, **driver_kwargs)] + [ - worker.execute_method_async(method, *args, **kwargs) + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") for worker in self.workers ] - return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index dd3ee60682d30..bed356d1b6e58 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -42,6 +42,8 @@ def _init_executor(self) -> None: self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + self.extra_execute_model_run_workers_kwargs[ + "use_ray_compiled_dag"] = True def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -171,23 +173,23 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - all_outputs = self._run_workers( - "execute_model", - driver_kwargs={"execute_model_req": execute_model_req}, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. - # Only the driver worker returns the sampling results. - return all_outputs[0] + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, @@ -198,9 +200,11 @@ def _run_workers( """Runs the given method on all workers. Can be used in the following ways: + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. - args/kwargs: All workers share the same args/kwargs - - args/kwargs and driver_args/driver_kwargs: Driver worker has - different args - all_args/all_kwargs: args/kwargs for each worker are specified individually """ @@ -209,11 +213,6 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - if driver_args is None: - driver_args = args if all_args is None else all_args[0] - if driver_kwargs is None: - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - count = len(self.workers) all_worker_args = repeat(args, count) if all_args is None \ else islice(all_args, 1, None) @@ -225,6 +224,7 @@ def _run_workers( # input. TODO(sang): Fix it. assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) + ray_worker_outputs = [] else: # Start the ray workers first. ray_worker_outputs = [ @@ -234,6 +234,13 @@ def _run_workers( ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( @@ -260,6 +267,11 @@ def _run_workers( return [driver_worker_output] + ray_worker_outputs + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + def _compiled_ray_dag(self): import pkg_resources required_version = "2.9" @@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.driver_executor = make_async(self.driver_worker.execute_method) + self.driver_exec_method = make_async(self.driver_worker.execute_method) - async def _run_workers_async( + async def _driver_execute_model_async( self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - coros = [] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - coros.append( - self.driver_executor(method, *driver_args, **driver_kwargs)) - - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) - - all_outputs = await asyncio.gather(*coros) - return all_outputs + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 9628f7af5315a..c2b22f2acd7b4 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -47,7 +47,9 @@ def set_include_gpu_probs_tensor(self): # NGram don't need gpu sampler pass - def execute_model(self, execute_model_req: ExecuteModelRequest) -> None: + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None) -> None: """NGram doesn't depend on model execution, just pass this function""" pass diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ef17b8c1e2cc0..3462a876c3e90 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -231,35 +231,6 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - def _broadcast_control_flow_decision( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - disable_all_speculation: bool = False) -> Tuple[int, bool]: - """Broadcast how many lookahead slots are scheduled for this step, and - whether all speculation is disabled, to all non-driver workers. - - This is required as if the number of draft model runs changes - dynamically, the non-driver workers won't know unless we perform a - communication to inform then. - - Returns the broadcasted num_lookahead_slots and disable_all_speculation. - """ - - if self.rank == self._driver_rank: - assert execute_model_req is not None - - broadcast_dict = dict( - num_lookahead_slots=execute_model_req.num_lookahead_slots, - disable_all_speculation=disable_all_speculation, - ) - broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) - else: - assert execute_model_req is None - broadcast_dict = broadcast_tensor_dict(src=self._driver_rank) - - return (broadcast_dict["num_lookahead_slots"], - broadcast_dict["disable_all_speculation"]) - @torch.inference_mode() def execute_model( self, @@ -267,39 +238,58 @@ def execute_model( ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ + if self.rank != self._driver_rank: + self._run_non_driver_rank() + return [] - disable_all_speculation = False - if self.rank == self._driver_rank: - disable_all_speculation = self._should_disable_all_speculation( - execute_model_req) - - (num_lookahead_slots, - disable_all_speculation) = self._broadcast_control_flow_decision( - execute_model_req, disable_all_speculation) - - if self.rank == self._driver_rank: - assert execute_model_req is not None - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding requires non-None seq_group_metadata_list" - ) - - self._maybe_disable_speculative_tokens( - disable_all_speculation, - execute_model_req.seq_group_metadata_list) - - # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. - if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all_speculation) - - return self._run_speculative_decoding_step(execute_model_req, - num_lookahead_slots) - else: - self._run_non_driver_rank(num_lookahead_slots) + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) return [] + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. + + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform then. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + def _should_disable_all_speculation( self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding @@ -346,13 +336,19 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.logprobs = None return [sampler_output] - def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: + def _run_non_driver_rank(self) -> bool: """Run proposer and verifier model in non-driver workers. This is used for both speculation cases (num_lookahead_slots>0) and non-speculation cases (e.g. prefill). + + Returns True iff there are remaining sequences to process. """ - # In non-driver workers the input is None - execute_model_req = None + assert self.rank != self._driver_rank + + data = broadcast_tensor_dict(src=self._driver_rank) + if not data: + return False + num_lookahead_slots = data["num_lookahead_slots"] # Even if num_lookahead_slots is zero, we want to run the proposer model # as it may have KV. @@ -360,9 +356,10 @@ def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: # We run the proposer once per lookahead slot. In the future we should # delegate how many times it runs to the proposer. for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model(execute_model_req) + self.proposer_worker.execute_model() - self.scorer_worker.execute_model(execute_model_req) + self.scorer_worker.execute_model() + return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 91f30978ead87..ef02de95fc54e 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -47,7 +47,7 @@ def __init__( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, @@ -84,10 +84,11 @@ def execute_model( def prepare_input_tensors( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: + assert seq_group_metadata_list is not None # Prepare input tensors. ( input_tokens, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9720363ac300e..87d5f5c1b9d67 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -609,10 +609,11 @@ def _prepare_model_input( def prepare_input_tensors( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: + assert seq_group_metadata_list is not None # Prepare input tensors. ( input_tokens, @@ -676,7 +677,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 97b3873b2a9f6..10411a2bf7a10 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -226,48 +226,42 @@ def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[Union[SamplerOutput, PoolerOutput]]: + if not self.is_driver_worker: + self._execute_model_non_driver() + return [] if execute_model_req is None: - seq_group_metadata_list = None - else: - seq_group_metadata_list = execute_model_req.seq_group_metadata_list + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] - blocks_to_swap_in: torch.Tensor - blocks_to_swap_out: torch.Tensor - blocks_to_copy: torch.Tensor - if self.is_driver_worker: - assert seq_group_metadata_list is not None - assert execute_model_req is not None - num_seq_groups = len(seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor( - execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor( - execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + num_seq_groups = len(seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", dtype=torch.int64).view(-1, 2) - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - else: - data = broadcast_tensor_dict(src=0) - num_seq_groups = data["num_seq_groups"] - blocks_to_swap_in = data["blocks_to_swap_in"] - blocks_to_swap_out = data["blocks_to_swap_out"] - blocks_to_copy = data["blocks_to_copy"] + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -282,6 +276,39 @@ def execute_model( # to conform to interface. return [output] + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while self._execute_model_non_driver(): + pass + + def _execute_model_non_driver(self) -> bool: + """Execute model in parallel worker. + + Returns True iff there are remaining sequences to process. + """ + assert not self.is_driver_worker + data = broadcast_tensor_dict(src=0) + if not data: + return False + + num_seq_groups = data.get("num_seq_groups", 0) + blocks_to_swap_in = data.get("blocks_to_swap_in") + blocks_to_swap_out = data.get("blocks_to_swap_out") + blocks_to_copy = data.get("blocks_to_copy") + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return False + + self.model_runner.execute_model(None, self.gpu_cache) + return True + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1f04f821eb0f0..dbac1b5ba339b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,7 +1,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -48,8 +48,9 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From a36de682d4283c60777bc3022ed3ce71cd90b904 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 22 May 2024 15:26:56 -0700 Subject: [PATCH 101/124] [Minor] Fix small typo in llama.py: QKVParallelLinear -> QuantizationConfig (#4991) --- vllm/model_executor/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 23141124e69e1..f43a40a0bfd34 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -57,7 +57,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QKVParallelLinear] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, ) -> None: super().__init__() From ee3eea0a1b2c690557455d97074d8829d5a98320 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 22 May 2024 15:55:56 -0700 Subject: [PATCH 102/124] [Misc] Take user preference in attention selector (#4960) --- tests/kernels/test_attention_selector.py | 84 +++++++++++++ vllm/attention/backends/flashinfer.py | 1 + vllm/attention/selector.py | 145 +++++++++++++---------- 3 files changed, 169 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_attention_selector.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py new file mode 100644 index 0000000000000..f439afa9b7d2b --- /dev/null +++ b/tests/kernels/test_attention_selector.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import patch + +import pytest +import torch + +from vllm.attention.selector import which_attn_to_use + + +@pytest.mark.parametrize( + "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) +def test_env(name: str, device: str): + """Test that the attention selector can be set via environment variable. + Note that we do not test FlashAttn because it is the default backend. + """ + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = name + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_flash_attn(): + """Test FlashAttn validation.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" + + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_invalid_env(): + """Throw an exception if the backend name is invalid.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" + with pytest.raises(ValueError): + which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7210fefbd8162..7b7959d257fac 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -218,6 +218,7 @@ def forward( ) if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. assert prefill_meta.block_tables is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: output = flash_attn_varlen_func( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 51c25a81b4130..f191461dcd3b7 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -30,24 +30,16 @@ def get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, ) -> Type[AttentionBackend]: - backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size) + """Determine which attention backend to use and only import + the selected backend module. + """ + backend = which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - - # We check it here not in _which_attn_to_use because we cannot know - # the head size until we import FlashAttentionBackend. - supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size in supported_head_sizes: - logger.info("Using FlashAttention-2 backend.") - return FlashAttentionBackend - logger.info( - "Cannot use FlashAttention-2 backend for head size %d. " - "Using XFormers backend instead.", head_size) - backend = _Backend.XFORMERS - + return FlashAttentionBackend if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -64,14 +56,15 @@ def get_attn_backend( return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend.") + logger.warning("Eager mode is required for the Flashinfer backend. " + "Please make sure --enforce-eager is set.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use( +def which_attn_to_use( num_heads: int, head_size: int, num_kv_heads: int, @@ -81,54 +74,84 @@ def _which_attn_to_use( block_size: int, ) -> _Backend: """Returns which flash attention backend to use.""" + + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + backend_members = _Backend.__members__ + if backend_by_env_var not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_by_env_var}'. " + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + selected_backend = _Backend[backend_by_env_var] + if is_cpu(): + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA if is_hip(): # AMD GPUs. - if torch.cuda.get_device_capability()[0] != 9: - # not Instinct series GPUs. - logger.info("flash_atten is not supported on NAVI GPUs.") + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH - # NVIDIA GPUs. - if torch.cuda.get_device_capability()[0] < 8: - # Volta and Turing NVIDIA GPUs. - logger.info("Cannot use FlashAttention-2 backend for Volta and Turing " - "GPUs.") - return _Backend.XFORMERS - - if dtype not in (torch.float16, torch.bfloat16): - logger.info("Cannot use FlashAttention-2 backend for dtype other than " - "torch.float16 or torch.bfloat16.") - return _Backend.XFORMERS - - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - - try: - import vllm_flash_attn # noqa: F401 - except ImportError: - logger.info( - "Cannot use FlashAttention-2 backend because the vllm_flash_attn " - "package is not found. `pip install vllm-flash-attn` for better " - "performance.") - return _Backend.XFORMERS - - backend_by_env_var = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - return _Backend[backend_by_env_var] - - # Default case. - return _Backend.FLASH_ATTN + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if torch.cuda.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info( + "Cannot use FlashAttention-2 backend for FP8 KV cache.") + selected_backend = _Backend.XFORMERS + elif block_size % 16 != 0: + logger.info( + "Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + selected_backend = _Backend.XFORMERS + + # FlashAttn is valid for the model, checking if the package is installed. + if selected_backend == _Backend.FLASH_ATTN: + try: + import vllm_flash_attn # noqa: F401 + + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size) + selected_backend = _Backend.XFORMERS + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "vllm_flash_attn package is not found. " + "`pip install vllm-flash-attn` for better performance.") + selected_backend = _Backend.XFORMERS + + return selected_backend From 606625329648e6eff1883e23040adfad82f219cf Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 23 May 2024 02:39:27 -0400 Subject: [PATCH 103/124] Marlin 24 prefill performance improvement (about 25% better on average) (#4983) --- benchmarks/kernels/benchmark_marlin.py | 74 ++++++++++++++++--- .../marlin/sparse/marlin_24_cuda_kernel.cu | 55 ++++++++++---- tests/kernels/test_marlin_gemm.py | 2 +- .../layers/quantization/gptq_marlin_24.py | 8 +- 4 files changed, 107 insertions(+), 32 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 5dcffc284f3d4..b771911781574 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -6,9 +6,13 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, marlin_quantize) + MarlinWorkspace, marlin_24_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) @@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, marlin_rand_perm, ) = marlin_quantize(b, num_bits, group_size, act_order) + # Marlin_24 quant + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + # GPTQ quant (w_ref, q_w, s, g_idx, rand_perm) = quantize_weights(b, num_bits, group_size, act_order) @@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare - marlin_workspace = MarlinWorkspace(size_n) + marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) globals = { + # Gen params + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "a": a, + "a_tmp": a_tmp, + # Marlin params "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, "marlin_rand_perm": marlin_rand_perm, + "marlin_workspace": marlin_workspace, + "is_k_full": is_k_full, + # Marlin_24 params + "marlin_24_w_ref": marlin_24_w_ref, + "marlin_24_q_w_comp": marlin_24_q_w_comp, + "marlin_24_meta": marlin_24_meta, + "marlin_24_s": marlin_24_s, + "marlin_24_workspace": marlin_24_workspace, + # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, - "num_bits": num_bits, - "group_size": group_size, - "size_m": size_m, - "size_n": size_n, - "size_k": size_k, - "is_k_full": is_k_full, - "a": a, - "a_tmp": a_tmp, + # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, - "marlin_workspace": marlin_workspace, } min_run_time = 1 @@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, description="gptq_marlin_gemm", ).blocked_autorange(min_run_time=min_run_time)) + if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_24_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + results.append( benchmark.Timer( stmt= @@ -135,8 +170,20 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: + if len(args.limit_act_order + ) > 0 and act_order not in args.limit_act_order: + continue + for is_k_full in K_FULL_OPTS: + if len(args.limit_k_full + ) > 0 and is_k_full not in args.limit_k_full: + continue + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + if len(args.limit_num_bits + ) > 0 and num_bits not in args.limit_num_bits: + continue + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: if len( args.limit_group_size @@ -159,7 +206,7 @@ def main(args): # For quick benchmarking use: -# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 # if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -178,6 +225,9 @@ def main(args): parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) args = parser.parse_args() main(args) diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 54ad27676e207..686dd7851e6af 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -48,12 +48,12 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int THREADS = 256; -static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory +static constexpr int STAGES = 4; static constexpr int min_thread_n = 128; static constexpr int tile_size = 16; -static constexpr int max_par = 16; +static constexpr int max_par = 64; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -736,10 +736,10 @@ __global__ void Marlin_24( for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + matmul(pipe); wait_for_stage(); fetch_to_registers(pipe + 1, (pipe + 1) % stages); - matmul(pipe); pipe++; slice_iters--; @@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, // than better compute utilization thread_k = 128; thread_m = 128; - } else { + } else if (prob_n <= 256) { thread_k = 64; thread_m = 256; + } else { + thread_k = 32; + thread_m = 512; } } @@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; + constexpr int max_m_blocks = 4; + int* locks = (int*)workspace; - for (int i = 0; i < tot_n_blocks; i += 4) { + for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; int par = 1; - if (thread_n_blocks > 4) { + if (thread_n_blocks > max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * thread_n_blocks - pad) / 64; + par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); if (par > max_par) par = max_par; - prob_n = 64 * par; - i += 4 * (par - 1); - thread_n_blocks = 4; + prob_n = (max_m_blocks * 16) * par; + i += max_m_blocks * (par - 1); + thread_n_blocks = max_m_blocks; } // For compilation speed, we only define the kernel configurations that have @@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, if (false) { } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 @@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, CALL_IF_2_4(4, 16, 4, 2, -1) CALL_IF_2_4(4, 16, 4, 2, 4) + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 2, 1, 4) + CALL_IF_2_4(4, 32, 3, 1, -1) + CALL_IF_2_4(4, 32, 3, 1, 4) + CALL_IF_2_4(4, 32, 4, 1, -1) + CALL_IF_2_4(4, 32, 4, 1, 4) + // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 @@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, CALL_IF_2_4(8, 16, 3, 2, 4) CALL_IF_2_4(8, 16, 4, 2, -1) CALL_IF_2_4(8, 16, 4, 2, 4) + + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 2, 1, 4) + CALL_IF_2_4(8, 32, 3, 1, -1) + CALL_IF_2_4(8, 32, 3, 1, 4) + CALL_IF_2_4(8, 32, 4, 1, -1) + CALL_IF_2_4(8, 32, 4, 1, 4) else { throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]" + @@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int thread_k = -1; int thread_m = -1; int sms = -1; - int max_par = 16; + int max_par = marlin_24::max_par; int groupsize = -1; if (b_scales.size(0) > 1) { diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 587fc3901eb7c..1f8d94bad26d9 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -27,7 +27,7 @@ MARLIN_N_CHUNKS = [64, 128, 256] MARLIN_24_K_CHUNKS = [128] -MARLIN_24_N_CHUNKS = [256] +MARLIN_24_N_CHUNKS = [512] MNK_FACTORS = [ (1, 1, 1), diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index f5345c0443029..6bcfc405afe71 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -15,7 +15,7 @@ GPTQ_MARLIN_24_TILE = 16 GPTQ_MARLIN_24_MIN_THREAD_N = 128 GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 16 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] @@ -53,14 +53,14 @@ def __init__( self.tile_size = 16 # Min out_features dim - self.min_n_threads = 128 + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N # Min in_features dim - self.min_k_threads = 128 + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K # Max parallel problems to solve at once (improves large # batch performance) - self.max_parallel = 16 + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL # Permutation length used by the marlin kernels. self.perm_len = 1024 From 2ba80bed2732edf42b1014ea4e34757849fc93d0 Mon Sep 17 00:00:00 2001 From: Letian Li Date: Thu, 23 May 2024 17:08:58 +0100 Subject: [PATCH 104/124] [Bugfix] Update Dockerfile.cpu to fix NameError: name 'vllm_ops' is not defined (#5009) --- .buildkite/run-cpu-test.sh | 2 +- Dockerfile.cpu | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f187d1f181724..414045fe163e5 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -11,4 +11,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4251fddd6cc3b..aec79824213f3 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -17,4 +17,6 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +WORKDIR /workspace/ + CMD ["/bin/bash"] From 5eda2ea02a01b2457f4d6ac2a217f2fa8a2e5d5f Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Thu, 23 May 2024 09:54:48 -0700 Subject: [PATCH 105/124] [Core][1/N] Support send/recv in PyNCCL Groups (#4988) Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_pynccl.py | 75 +++++++++++++++++-- vllm/distributed/communication_op.py | 18 +++-- .../device_communicators/pynccl.py | 34 +++++++++ .../device_communicators/pynccl_wrapper.py | 26 +++++++ vllm/distributed/parallel_state.py | 34 +++++++-- 5 files changed, 170 insertions(+), 17 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 529e75fb2c9e3..0218295a3e3f9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,6 +3,7 @@ import pytest import torch +import torch.distributed from vllm.distributed.communication_op import ( # noqa graph_capture, tensor_model_parallel_all_reduce) @@ -68,7 +69,7 @@ def test_pynccl(): @worker_fn_wrapper -def multiple_tp_worker_fn(): +def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 1], backend="gloo"), @@ -92,14 +93,14 @@ def multiple_tp_worker_fn(): @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") -def test_pynccl_multiple_tp(): +def test_pynccl_multiple_allreduce(): # this tests pynccl for multiple tp groups, in a standalone way # i.e. call `pynccl_comm.all_reduce` directly - distributed_run(multiple_tp_worker_fn, 4) + distributed_run(multiple_allreduce_worker_fn, 4) @worker_fn_wrapper -def multiple_tp_with_vllm_worker_fn(): +def multiple_allreduce_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) @@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn(): @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") -def test_pynccl_multiple_tp_with_vllm(): +def test_pynccl_multiple_allreduce_with_vllm(): # this tests pynccl for multiple tp groups, together with vllm # i.e. call `tensor_model_parallel_all_reduce` - distributed_run(multiple_tp_with_vllm_worker_fn, 4) + distributed_run(multiple_allreduce_with_vllm_worker_fn, 4) @worker_fn_wrapper @@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) +@worker_fn_wrapper +def send_recv_worker_fn(): + pynccl_comm = PyNcclCommunicator() + if pynccl_comm.rank == 0: + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + else: + tensor = torch.empty(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + if pynccl_comm.rank == 0: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + assert result == 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_send_recv(): + distributed_run(send_recv_worker_fn, 2) + + +@worker_fn_wrapper +def multiple_send_recv_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 2], backend="gloo"), + torch.distributed.new_group(ranks=[1, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] + pynccl_comm = PyNcclCommunicator(group=group, device=device) + if torch.distributed.get_rank() == 0: + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + elif torch.distributed.get_rank() == 1: + tensor = 2 * torch.ones( + 16, 1024, 1024, dtype=torch.float32, device=device) + else: + tensor = torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=device) + with pynccl_comm.change_state(enable=True): + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + if torch.distributed.get_rank() in [0, 2]: + assert result == 1 + else: + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_send_recv(): + distributed_run(multiple_send_recv_worker_fn, 4) + + def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 937fd4d392713..2b38ec472de66 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -6,7 +6,7 @@ import torch from torch.distributed import ProcessGroup -from .parallel_state import (get_cpu_world_group, +from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -54,13 +54,19 @@ def graph_capture(): # graph, we use either custom all-reduce kernel or PyTorch NCCL. # We always prioritize using custom all-reduce kernel but fall back # to PyTorch or pynccl if it is disabled or not supported. - pynccl_comm = get_tp_pynccl_communicator() - if pynccl_comm is None: - maybe_pynccl_context = nullcontext() + tp_pynccl_comm = get_tp_pynccl_communicator() + pp_pynccl_comm = get_pp_pynccl_communicator() + if not tp_pynccl_comm: + maybe_tp_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( + maybe_tp_pynccl_context = tp_pynccl_comm.change_state( enable=True, stream=torch.cuda.current_stream()) - with maybe_pynccl_context: + if not pp_pynccl_comm: + maybe_pp_pynccl_context = nullcontext() + else: + maybe_pp_pynccl_context = pp_pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_tp_pynccl_context, maybe_pp_pynccl_context: yield graph_capture_context diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 092a0910329ad..f5f1de0c71615 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -126,6 +126,40 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def send(self, + tensor: torch.Tensor, + dst: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if dst is None: + dst = (self.rank + 1) % self.world_size + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, + tensor: torch.Tensor, + src: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src is None: + src = (self.rank - 1) % self.world_size + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + @contextmanager def change_state(self, enable: Optional[bool] = None, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 43d85674b23d0..3aa3744d0d827 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -151,6 +151,22 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -248,6 +264,16 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d24104e3ed276..0ebd7a15eab9b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,8 @@ _TP_CA_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. _PP_DEVICE_GROUP: Optional[ProcessGroup] = None +_PP_CPU_GROUP: Optional[ProcessGroup] = None +_PP_PYNCCL_COMMUNICATOR = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def get_pp_pynccl_communicator(): + global _PP_PYNCCL_COMMUNICATOR + return _PP_PYNCCL_COMMUNICATOR + + def get_tp_pynccl_communicator(): global _TP_PYNCCL_COMMUNICATOR return _TP_PYNCCL_COMMUNICATOR @@ -180,10 +187,11 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( - group=_TP_CPU_GROUP, - device=_LOCAL_RANK, - ) + if tensor_model_parallel_size > 1: + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) # Initialize a custom fast all-reduce implementation. if _ENABLE_CUSTOM_ALL_REDUCE: @@ -195,17 +203,26 @@ def initialize_model_parallel( ) # Build the pipeline model-parallel groups. - global _PP_DEVICE_GROUP + global _PP_DEVICE_GROUP, _PP_CPU_GROUP + global _PP_PYNCCL_COMMUNICATOR global _PP_GLOBAL_RANKS assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _PP_DEVICE_GROUP = group + _PP_CPU_GROUP = cpu_group _PP_GLOBAL_RANKS = ranks + if pipeline_model_parallel_size > 1: + _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_PP_CPU_GROUP, + device=_LOCAL_RANK, + ) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, @@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group(): return _PP_DEVICE_GROUP +def get_pipeline_model_parallel_cpu_group(): + """Get the pipeline model parallel cpu group the caller rank belongs to.""" + assert _PP_CPU_GROUP is not None, ( + "pipeline model parallel cpu group is not initialized") + return _PP_CPU_GROUP + + def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" return torch.distributed.get_world_size( From a1242324c99ff8b1e29981006dfb504da198c7c3 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 17:29:18 -0400 Subject: [PATCH 106/124] [Kernel] Initial Activation Quantization Support (#4525) Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- CMakeLists.txt | 1 + csrc/ops.h | 3 + csrc/pybind.cpp | 3 + .../compressed_tensors/int8_quant_kernels.cu | 59 +++++ tests/kernels/test_int8_quant.py | 31 +++ tests/quantization/test_compressed_tensors.py | 36 +++ vllm/_custom_ops.py | 18 ++ vllm/model_executor/layers/linear.py | 244 ++++++++++++------ .../layers/quantization/__init__.py | 3 + .../compressed_tensors/__init__.py | 0 .../compressed_tensors/compressed_tensors.py | 151 +++++++++++ .../compressed_tensors/schemes/__init__.py | 5 + .../schemes/compressed_tensors_scheme.py | 33 +++ .../schemes/compressed_tensors_unquantized.py | 39 +++ .../compressed_tensors_w8a8_statictensor.py | 119 +++++++++ .../model_loader/weight_utils.py | 7 + vllm/model_executor/models/llama.py | 25 +- 17 files changed, 683 insertions(+), 94 deletions(-) create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu create mode 100644 tests/kernels/test_int8_quant.py create mode 100644 tests/quantization/test_compressed_tensors.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 35846fd1cfa99..b668cbc97de15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/ops.h b/csrc/ops.h index f5e0e423bb65d..b839eaf0d26c8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, #endif +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + float scale); + void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cba07f0ae9f2a..cdbec4a34d77f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); + ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, + "Compute int8 quantized tensor for given scaling factor"); + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..4902e4c23434c --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), scale, + hidden_size); + }); +} diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py new file mode 100644 index 0000000000000..b9aa00ce13f56 --- /dev/null +++ b/tests/kernels/test_int8_quant.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(out2, x, scale) + assert torch.allclose(out1, out2, + atol=1) # big atol to account for rounding errors diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py new file mode 100644 index 0000000000000..b83286992da3d --- /dev/null +++ b/tests/quantization/test_compressed_tensors.py @@ -0,0 +1,36 @@ +"""Test model set-up and weight loading for sparseml-quantized models. + +Run `pytest tests/quantization/test_compressed_tensors.py`. +""" + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) + + +def test_compressed_tensors_w8a8_static_setup(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" + llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 + + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9e7d0d96bf004..f0fab4d8aa26d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -251,6 +251,24 @@ def scaled_fp8_quant( return output, scale +# int8 +def static_scaled_int8_quant(input: torch.Tensor, + scale: float) -> torch.Tensor: + """ + Quantize the input tensor to int8 and return the quantized tensor. + + Args: + input: The input tensor to be quantized to int8. + scale: Scaling factor for the int8 quantization. + + Returns: + torch.Tensor: Output tensor in int8. + """ + q = torch.empty_like(input, dtype=torch.int8) + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7726dcb9a5fbd..34fbfa8e33ef9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -56,7 +56,6 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -77,8 +76,7 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - output_size_per_partition = sum(output_partition_sizes) - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -149,15 +147,13 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -210,17 +206,15 @@ class ColumnParallelLinear(LinearBase): the list would be size 3. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -228,18 +222,26 @@ def __init__( # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + if output_sizes is None: output_sizes = [output_size] - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -317,22 +319,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, quant_config, - self.output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -343,6 +347,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -403,6 +427,13 @@ def weight_loader(self, shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -415,6 +446,14 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -443,17 +482,15 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -473,14 +510,19 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - output_sizes = [ - self.num_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, quant_config, output_sizes) + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -490,6 +532,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -528,6 +590,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -567,6 +631,12 @@ def weight_loader(self, shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -578,6 +648,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -608,17 +685,15 @@ class RowParallelLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -628,16 +703,15 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) - + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -665,12 +739,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer(param_data, loaded_weight, shard_id=0) + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f938e7d37ec5f..7b9abe1b629a1 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config @@ -27,6 +29,7 @@ "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "sparseml": CompressedTensorsConfig, } diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000000000..19e464bd64325 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, ignore=ignore) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Refactor as additional cases are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric and \ + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for + # now. + return CompressedTensorsW8A8StaticTensor() + raise NotImplementedError( + "Scheme not supported. Only CUDA, 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") + + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer}.") + + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. + """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000000000..831905b63e2c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,5 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000000000..3a5904208656e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 0000000000000..0cfac13d1ca25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,39 @@ +from typing import Callable, List + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. + """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 0000000000000..d16e570d12202 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,119 @@ +from typing import Callable, List, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as custom_ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + + # Note on input/weight scales and zero_points + # + # When the scales have a single value, it is required that they be + # on the CPU for 2 reasons, + # 1. Performance: + # When the scales (input_scale/weight_scales) have only a single + # value, we perform a scalar broadcast of that value during the + # quant/dequant operations. The "quant" and the "gemm+dequant" + # kernels accept the Scalar by-value. These tensors are allocated + # on the CPU in order to avoid the GPU-to-CPU copy when passing + # by-value. + # + # 2. CUDA Graphs: + # CUDA Graphs don't support GPU-to-CPU copy operations during + # stream capture. + # + # TODO: zero-points are not supported yet. But we expect a similar + # pattern. + + is_tensor_partitioned = len(output_partition_sizes) != 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" + + input_scale = Parameter(torch.empty(1, + device="cpu", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + + # Input quantize + x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, + weight_scale, x.dtype) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index a1642baa2c90c..f9b1dc60dd006 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -120,6 +120,13 @@ def get_quant_config(model_config: ModelConfig, # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + compression_config = getattr(model_config.hf_config, + "compression_config", None) + if compression_config is not None: + hf_quant_config = compression_config.get("quantization_config", + None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) model_name_or_path = model_config.model diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f43a40a0bfd34..086f9294c4f1c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -62,11 +62,12 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, bias=bias, quant_config=quant_config) if hidden_act != "silu": @@ -120,16 +121,16 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, quant_config=quant_config, ) @@ -263,8 +264,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From e3470f87538ec86d1094ac4747519c6300213088 Mon Sep 17 00:00:00 2001 From: Elisei Smirnov <61423871+kezouke@users.noreply.github.com> Date: Fri, 24 May 2024 01:04:24 +0300 Subject: [PATCH 107/124] [Core]: Option To Use Prompt Token Ids Inside Logits Processor (#4985) Co-authored-by: Elisei Smirnov --- vllm/model_executor/layers/logits_processor.py | 17 ++++++++++++++--- vllm/sampling_params.py | 15 ++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 91eb96998c3cf..d450c46455d49 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,4 +1,5 @@ """A layer that compute logits from hidden_stats.""" +import inspect from typing import Optional import torch @@ -95,15 +96,25 @@ def _apply_logits_processors( seq_ids = seq_group.seq_ids sampling_params = seq_group.sampling_params logits_processors = sampling_params.logits_processors - if logits_processors: found_logits_processors = True + for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): logits_row = logits[logits_row_idx] - token_ids = seq_group.seq_data[seq_id].output_token_ids + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, + past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, + logits_row) + logits[logits_row_idx] = logits_row logits_processed += len(seq_group.sample_indices) + len( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5fa94eb149ffb..9d8a361353e26 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -18,10 +18,14 @@ class SamplingType(IntEnum): BEAM = 3 -LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] -"""LogitsProcessor is a function that takes a list of previously generated -tokens and a tensor of the logits for the next token, and returns a modified -tensor of logits to sample from.""" +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[List[int], List[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" class SamplingParams: @@ -95,7 +99,8 @@ class SamplingParams: spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. logits_processors: List of functions that modify logits based on - previously generated tokens. + previously generated tokens, and optionally prompt tokens as + a first argument. truncate_prompt_tokens: If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). From 6a50f4cafaf9f734b3f6ad11e6af38838aa3baf8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 May 2024 16:21:54 -0700 Subject: [PATCH 108/124] [Doc] add ccache guide in doc (#5012) Co-authored-by: Michael Goin --- docs/source/getting_started/installation.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 0c81f7ec6d2a9..ba23e7468dcc1 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -56,6 +56,10 @@ You can also build and install vLLM from source: $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. +.. tip:: + + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs to be run simultaneously, via the environment variable `MAX_JOBS`. For example: From 919770957f26d71a5a6eda7a1a7443dfeb5ba0ee Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 24 May 2024 14:28:27 +0200 Subject: [PATCH 109/124] [Bugfix] Fix Mistral v0.3 Weight Loading (#5005) Co-authored-by: Cody Yu --- tests/models/test_mistral.py | 1 + vllm/model_executor/model_loader/loader.py | 17 ++++- .../model_loader/weight_utils.py | 64 ++++++++++++++++++- 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index d0a5bfbfcd922..76b248cf14e98 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -8,6 +8,7 @@ MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 45ea8160a801b..b7b5b5e7695f4 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -23,7 +23,8 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, filter_files_not_needed_for_inference, + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.vlm_base import VisionLanguageModelBase @@ -188,7 +189,19 @@ def _prepare_weights(self, model_name_or_path: str, use_safetensors = True break - if not use_safetensors: + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, + revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder) + else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f9b1dc60dd006..53e21eba8fae3 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -12,9 +12,10 @@ import huggingface_hub.constants import numpy as np import torch -from huggingface_hub import HfFileSystem, snapshot_download +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger @@ -218,6 +219,67 @@ def download_weights_from_hf( return hf_folder +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have SAFE_WEIGHTS_INDEX_NAME. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: List[str], + hf_folder: str) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as index_file: + weight_map = json.load(index_file)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + def filter_files_not_needed_for_inference( hf_weights_files: List[str]) -> List[str]: """ From e64fde4b013cb8bb2321f59ba78aca50b02071cb Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Sat, 25 May 2024 01:07:09 +0800 Subject: [PATCH 110/124] [Core][Bugfix]: fix prefix caching for blockv2 (#4764) Co-authored-by: Lei Wen --- tests/core/block/test_prefix_caching_block.py | 117 ++++++++++++++++++ vllm/core/block/prefix_caching_block.py | 41 +++--- 2 files changed, 141 insertions(+), 17 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index c4c680e109a84..bcf08cda09f46 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -410,6 +410,123 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, assert (len(res) == zero_point_blocks) + # Test case that assume those prompted block after first immutable would + # be freed into hashless allocator, while first immutable block get ref + # increased. + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(block_size)) + + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids) + + assert allocator._refcounter.get(block.block_id) == 1 + m = allocator.allocate_mutable(prev_block=None) + + block_id = m.block_id + for i in range(block_size): + m.append_token_ids([i]) + # After block get promoted to immutable from mutable, if there is + # already same content hash block, then it shall be released into + # hashless_allocator + # And first immutable block's ref get increased by 1 + assert m.block_id == block.block_id + assert block_id in allocator._hashless_allocator._free_block_indices + assert allocator._refcounter.get(block.block_id) == 2 + + # Test case when eviction and allocation are mixed, + # make sure they work as expected + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + all_blocks_list = [i for i in range(num_blocks)] + zero_ref = {i: 0 for i in range(num_blocks)} + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(num_blocks * block_size)) + + # now we have num_blocks free blocks in hashless allocator + # with internal tracking list _blocks _cached_blocks and evictor + # empty and block's ref shall be 0 + assert list(allocator._hashless_allocator._free_block_indices + ) == all_blocks_list + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._cached_blocks.values()) == 0 + assert len(allocator.evictor.free_table.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + + # Allocate immutable chains with only one block residuled in + new_block = [] + for i in range(num_blocks): + block = allocator.allocate_immutable( + prev_block=None, + token_ids=token_ids[block_size * i:block_size * (i + 1)]) + new_block.append(block) + + # Free all blocks, and now all blocks shall be in the evictor + # there shall be no tracking data left in _blocks + # all blocks shall be tracked in _cached_blocks + # all blocks' ref shall be zero + for block in new_block: + allocator.free(block) + + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert list(allocator._cached_blocks.values()) == all_blocks_list + assert list(allocator.evictor.free_table.keys()) == all_blocks_list + assert allocator._refcounter._refcounts == zero_ref + + # Allocate a mutable block, and the first block shall be evicted + # and set its content hash into None, ref to 1 + mutable = allocator.allocate_mutable(prev_block=None) + + assert mutable.block_id == 0 + assert mutable.content_hash is None + assert 0 in allocator._blocks + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + + # Since this mutable block has no hash yet, it shall be released into + # hashless allocator + allocator.free(mutable) + + assert len(allocator._blocks.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + assert 0 in allocator._hashless_allocator._free_block_indices + + # when allocate immutable with first block_size tokens, we + # shall get free block from hashless allocator, thus no block left + # in hashless + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids[:block_size]) + + assert block.block_id == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert 0 in allocator._blocks + assert 0 in allocator._cached_blocks.values() + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator.evictor + + # allocate mutable block again, it shall be popped from evictor + mutable = allocator.allocate_mutable(prev_block=None) + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert mutable.block_id not in allocator.evictor.free_table + assert allocator._refcounter.get(mutable.block_id) == 1 + # Test case where two last accessed times are equal @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 882f301c1f697..4eb32f145b05b 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -160,21 +160,17 @@ def allocate_mutable(self, # If the evictor has blocks available for eviction, evict a block # and return it. if self.evictor.num_blocks > 0: + # here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list block_id, content_hash_to_evict = self.evictor.evict() - # Here we may have scenario that several blocks have - # the same content hash, but due to the latter coming block - # is coming from mutable to immutable path, their physical - # block is added into evictor. - # However in this case, we shall not pop the _cached_blocks, - # as the same content is still used by others, which means - # we need to check ref before decide to pop the list. - _block_id = self._cached_blocks[content_hash_to_evict] - refcount = self._refcounter.get(_block_id) - if refcount == 1: - self._cached_blocks.pop(content_hash_to_evict) - assert _block_id == block_id + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) self._refcounter.incr(block_id) @@ -199,7 +195,11 @@ def allocate_mutable(self, def _incr_refcount_cached_block(self, block: Block, block_id: BlockId) -> None: - # since block is already computed, mark it + # now _incr_refcount_cached_block comes from two place + # allocate_immutable/promote_to_immutable_block where hit + # _cached_blocks hash key. + # In both cases, it means that already exists a already + # computed block which shared with block now block.computed = True refcount = self._refcounter.incr(block_id) @@ -228,13 +228,19 @@ def _free_block_id_for_block(self, block_id: BlockId, block: Block) -> None: assert isinstance(block, PrefixCachingBlock) - if block.content_hash is None: + # if we comes from promote_to_immutable_block, it means that + # block.content_hash is never None. + # However we need to release the same content block, so that + # physical block could get reused. + if block.block_id != block_id or block.content_hash is None: refcount = self._refcounter.get(block_id) # We have fork case where block would get more than one ref, # so we cannot free it from tracking if ref cnt large than 1 - if refcount <= 1: - assert block.block_id is not None + assert block.block_id is not None + refcount = self._refcounter.get(block.block_id) + if refcount == 1: del self._blocks[block.block_id] + return self._hashless_allocator.free(block) refcount = self._refcounter.decr(block_id) @@ -317,7 +323,8 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: if block.content_hash not in self._cached_blocks: self._cached_blocks[block.content_hash] = block.block_id else: - self._free_block_id_for_block(block.block_id, block) + self._free_block_id_for_block( + self._cached_blocks[block.content_hash], block) self._incr_refcount_cached_block( block, self._cached_blocks[block.content_hash]) From 8e192ff967b44b186ea02d30e49fddf656fdfe50 Mon Sep 17 00:00:00 2001 From: Eric Xihui Lin Date: Sat, 25 May 2024 01:00:52 -0400 Subject: [PATCH 111/124] [Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799) Co-authored-by: beagleski Co-authored-by: bapatra Co-authored-by: Barun Patra Co-authored-by: Michael Goin --- csrc/attention/attention_kernels.cu | 185 ++++++-- csrc/cpu/attention.cpp | 37 +- csrc/ops.h | 35 +- docs/source/models/supported_models.rst | 4 + tests/kernels/test_blocksparse_attention.py | 442 +++++++++++++++++ vllm/_custom_ops.py | 30 +- vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/blocksparse_attn.py | 410 ++++++++++++++++ vllm/attention/backends/flash_attn.py | 5 +- vllm/attention/backends/rocm_flash_attn.py | 5 +- vllm/attention/backends/torch_sdpa.py | 5 +- vllm/attention/backends/xformers.py | 5 +- vllm/attention/layer.py | 10 +- .../ops/blocksparse_attention/__init__.py | 0 .../blocksparse_attention_kernel.py | 423 +++++++++++++++++ .../ops/blocksparse_attention/interface.py | 238 ++++++++++ .../ops/blocksparse_attention/utils.py | 216 +++++++++ vllm/attention/ops/paged_attn.py | 25 +- vllm/attention/selector.py | 7 + vllm/entrypoints/openai/serving_engine.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/phi3_small.py | 447 ++++++++++++++++++ vllm/transformers_utils/config.py | 2 +- 23 files changed, 2446 insertions(+), 88 deletions(-) create mode 100644 tests/kernels/test_blocksparse_attention.py create mode 100644 vllm/attention/backends/blocksparse_attn.py create mode 100644 vllm/attention/ops/blocksparse_attention/__init__.py create mode 100644 vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py create mode 100644 vllm/attention/ops/blocksparse_attention/interface.py create mode 100644 vllm/attention/ops/blocksparse_attention/utils.py create mode 100644 vllm/model_executor/models/phi3_small.py diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index d6203174e7275..45edc3252380c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -85,6 +85,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Grid: (num_heads, num_seqs, max_num_partitions). template // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -104,7 +105,9 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -202,11 +205,55 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } const int64_t physical_block_number = static_cast(block_table[block_idx]); @@ -335,6 +382,15 @@ __device__ void paged_attention_kernel( // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } const int64_t physical_block_number = static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; @@ -441,8 +497,8 @@ __device__ void paged_attention_kernel( // Grid: (num_heads, num_seqs, 1). template + int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, + bool IS_BLOCK_SPARSE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( + KV_DTYPE, IS_BLOCK_SPARSE>( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale); + kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs, max_num_partitions). template __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( + KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale); + kv_block_stride, kv_head_stride, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). @@ -607,25 +672,32 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel< \ - T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \ + ((void*)vllm::paged_attention_v1_kernel), \ shared_mem_size); \ vllm::paged_attention_v1_kernel \ + NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \ <<>>( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale); + kv_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template + vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, + int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale) { + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -691,23 +763,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v1_launcher( \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale); + seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -727,18 +812,26 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int block_size, int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale){ + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V1_LAUNCHER_BLOCK_SIZE)} #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel \ + NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ + PARTITION_SIZE> \ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale); \ + kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -746,14 +839,17 @@ void paged_attention_v1( max_num_partitions); template + vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, + int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale) { + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -824,24 +920,36 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v2_launcher( \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale); + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -865,7 +973,10 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int block_size, int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } @@ -873,4 +984,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 54df69b7379d6..438e9bdb19f50 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher( } } // namespace -void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, - torch::Tensor& key_cache, torch::Tensor& value_cache, - int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, - int block_size, int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher( } } // namespace -void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, - float scale, torch::Tensor& block_tables, - torch::Tensor& seq_lens, int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/ops.h b/csrc/ops.h index b839eaf0d26c8..567d9fae4bd2a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,23 +2,24 @@ #include -void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, - torch::Tensor& key_cache, torch::Tensor& value_cache, - int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, - int block_size, int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale); - -void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, - float scale, torch::Tensor& block_tables, - torch::Tensor& seq_lens, int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale); +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, float epsilon); diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 31d4b53bd4409..e4bae80343a2c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -123,6 +123,10 @@ Alongside each architecture, we include some popular models that use it. - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. - + * - :code:`Phi3SmallForCausalLM` + - Phi-3-Small + - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py new file mode 100644 index 0000000000000..9da13ca6e2310 --- /dev/null +++ b/tests/kernels/test_blocksparse_attention.py @@ -0,0 +1,442 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn) +from vllm.utils import get_max_shared_memory_bytes, is_hip + +from .allclose_default import get_default_atol, get_default_rtol + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = 2771 + +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +HEAD_SIZES = [64, 112] +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = ['cuda:0'] +BLOCKSPARSE_LOCAL_BLOCKS = [16] +BLOCKSPARSE_VERT_STRIDES = [8] + +BLOCKSPARSE_BLOCK_SIZES = [64] +BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1] +BLOCKSPARSE_HOMO_HEADS = [True, False] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 1, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + seq_len = int(seq_lens[i]) + + keys = [] + values = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + if blocksparse_vert_stride >= 1: + bsize = blocksparse_block_size + hsliding = blocksparse_head_sliding_step + vert = blocksparse_vert_stride + locals = blocksparse_local_blocks + qb = (seq_len - 1) // bsize + attn_mask = q.new_zeros( + (num_query_heads, 1, seq_len)).float() - torch.inf + for h in range(num_query_heads): + if hsliding >= 0: # slide with q heads + bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 + else: # slide with kv heads + bs_offset = (tp_rank * num_kv_heads + + h // num_queries_per_kv) * (-hsliding) + 1 + for kb in range(qb + 1): + kj = kb * bsize + if (qb - kb) < locals or \ + (kb + bs_offset) % vert == 0: + attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 + if alibi_bias is not None: + attn_mask += alibi_bias + else: + attn_mask = alibi_bias + + out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_head_sliding_step", + BLOCKSPARSE_HEADS_SLIDINGS) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_head_sliding_step: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.rand(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + tp_rank = 0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + elif version == "v2": + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_varlen_blocksparse_attention_prefill( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_homo_heads: bool, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + bs_attn_op = LocalStridedBlockSparseAttn( + num_query_heads, + max_len, + local_blocks=blocksparse_local_blocks, + vert_stride=blocksparse_vert_stride, + block_size=blocksparse_block_size, + device=device, + dtype=dtype, + homo_head=blocksparse_homo_heads) + + output = bs_attn_op(query, + key, + value, + cu_seq_lens.to(device), + sm_scale=scale) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f0fab4d8aa26d..22cf5a44e341f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -45,11 +45,17 @@ def paged_attention_v1( alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, kv_scale) + vllm_ops.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) def paged_attention_v2( @@ -69,12 +75,18 @@ def paged_attention_v2( alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, - max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale) + vllm_ops.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) # pos encoding ops diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 94ab64de30a94..6396103bf5efa 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -111,6 +111,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000000000..dce2b83615b7a --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,410 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata": + return BlocksparseFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + kv_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache is None \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0f4568070cfc4..0b9d6283493f2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,6 +1,6 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache @@ -219,7 +219,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 94f3f55636ed6..e92e6c5e2dc8d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -201,7 +201,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "ROCFlashAttention does not support blocksparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a19c97e1e0e35..9b50adec5244d 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -100,7 +100,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "Torch SPDA does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 96169da6cf92c..99a3e88bc07b6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,6 +1,6 @@ """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from xformers import ops as xops @@ -212,7 +212,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "XFormer does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dc7b3940bc9b7..b67f04c51d493 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -33,6 +33,7 @@ def __init__( sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() if cache_config is not None: @@ -69,10 +70,12 @@ def __init__( dtype = torch.get_default_dtype() attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - block_size) + block_size, blocksparse_params + is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params) def forward( self, @@ -90,4 +93,5 @@ def extra_repr(self) -> str: s += f", num_heads={self.impl.num_heads}" # type: ignore s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" return s diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000000000..ec1c37c5bcb0e --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,423 @@ +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000000000..300211e70bb79 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +import math + +import torch + +from vllm.utils import is_cpu, is_hip + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = is_hip() or is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if torch.cuda.is_available() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = [ + self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v] + ] + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) \ No newline at end of file diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000000000..0d90dd971e156 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,216 @@ +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import torch +import triton +from scipy import sparse + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 30feaa4da254d..e119fdcf11113 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -91,9 +91,21 @@ def forward_decode( scale: float, alibi_slopes: Optional[torch.Tensor], kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: - output = torch.empty_like(query) + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // @@ -107,6 +119,7 @@ def forward_decode( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -123,6 +136,11 @@ def forward_decode( alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) else: # Run PagedAttention V2. @@ -155,6 +173,11 @@ def forward_decode( alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) return output diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f191461dcd3b7..9ceda3431b898 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -29,7 +29,14 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + is_blocksparse: bool = False, ) -> Type[AttentionBackend]: + + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend """Determine which attention backend to use and only import the selected backend module. """ diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index db3fc85decd70..0df0223b9dbb2 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -100,6 +100,7 @@ def _create_logprobs( token_logprob = step_top_logprobs[token_id].logprob token = step_top_logprobs[token_id].decoded_token logprobs.tokens.append(token) + token_logprob = max(token_logprob, -9999.0) logprobs.token_logprobs.append(token_logprob) if num_output_top_logprobs: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6aec104be8da4..a92abe6b5b8dc 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -56,6 +56,7 @@ "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py new file mode 100644 index 0000000000000..0c5298eb6f100 --- /dev/null +++ b/vllm/model_executor/models/phi3_small.py @@ -0,0 +1,447 @@ +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + + +def load_column_parallel_weight(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + tp = get_tensor_model_parallel_world_size() + rk = get_tensor_model_parallel_rank() + assert param.size(0) * tp == loaded_weight.size(0) + s = rk * param.size(0) + e = (rk + 1) * param.size(0) + loaded_weight = loaded_weight[s:e] + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + +class HeadMajorQKVParallelLinear(QKVParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +@torch.jit.script +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +@torch.jit.script +def gegelu(input, limit: Optional[float] = None): + a_gelu, a_linear = input[..., ::2], input[..., 1::2] + if limit is not None: + a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, + a_gelu.clamp(min=None, max=limit)) + a_linear = torch.where( + torch.isinf(a_linear), + a_linear, + a_linear.clamp(min=-limit, max=limit), + ) + out_gelu = quick_gelu(a_gelu) + return out_gelu * (a_linear + 1) + + +class Phi3SmallMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + assert (self.config.hidden_act == "gegelu" + ), "Only `gegelu` is supported for the 4.7 series of models .." + self.hidden_size = config.hidden_size + self.gegelu_limit = config.gegelu_limit + self.intermediate_size = config.intermediate_size + + self.up_proj = HeadMajorColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], + bias=True, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + gate_up, _ = self.up_proj(x) + x = gegelu(gate_up) + x, _ = self.down_proj(x) + return x + + +class Phi3SmallSelfAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.sparse_block_size = config.blocksparse_block_size + self.homo_heads = config.blocksparse_homo_head_pattern + self.local_blocks = config.blocksparse_num_local_blocks + self.vert_stride = config.blocksparse_vert_stride + + assert (config.blocksparse_block_size == + config.blocksparse_triton_kernel_block_size) + + self.hidden_size = config.hidden_size + # Number of Query Heads + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // self.num_heads + self.tp_size = get_tensor_model_parallel_world_size() + # Number of total Key Value Heads before tensor parallel + self.num_key_value_heads = config.num_key_value_heads + self.num_q_per_kv = self.num_heads // self.num_key_value_heads + if self.tp_size > 1: + assert self.num_key_value_heads % self.tp_size == 0 + self.num_kv_heads_per_partion = max( + 1, self.num_key_value_heads // self.tp_size) + self.num_heads_per_partition = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_embedding_base = config.rope_embedding_base + self.rope_position_scale = config.rope_position_scale + self.is_causal = True + + norm_factor = None + if config.mup_use_scaling: + norm_factor = self.head_dim / config.mup_attn_multiplier + else: + norm_factor = math.sqrt(self.head_dim) + self.scale = 1 / norm_factor + + self.query_key_value = HeadMajorQKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=True, + quant_config=quant_config, + ) + + self.dense = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config) + + if getattr(self.config, "rope_scaling", None) is not None: + rope_scaling = self.config.rope_scaling + for key in rope_scaling: + if isinstance(rope_scaling[key], list): + rope_scaling[key] = tuple(rope_scaling[key]) + + if "factor" not in rope_scaling: + rope_scaling["factor"] = self.rope_position_scale + else: + rope_scaling = { + "type": "linear", + "factor": self.rope_position_scale, + } + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_embedding_base, + rope_scaling=rope_scaling, + ) + + # blocksparse params + self.blocksparse_block_size = config.blocksparse_block_size + self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks + self.blocksparse_vert_stride = config.blocksparse_vert_stride + + use_dense_attn = (getattr(self.config, + "dense_attention_every_n_layers", None) + and (self.layer_idx + 1) % + self.config.dense_attention_every_n_layers == 0) + + bs_params = None + if not use_dense_attn: + bs_params = { + 'max_seqlen': self.max_position_embeddings, + 'num_heads': self.num_heads_per_partition, + "num_kv_heads": self.num_kv_heads_per_partion, + "block_size": self.sparse_block_size, + "local_blocks": self.local_blocks, + "vert_stride": self.vert_stride, + "homo_head": self.homo_heads + } + + self.attn = Attention( + self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + qkv, _ = self.query_key_value(hidden_states) + + qkv = qkv.view(qkv.shape[:-1] + + (-1, (self.num_q_per_kv + 2), self.head_dim)) + q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) + + # NOTE: this is required by RotaryEmbed, which indeed does not have to + # TODO: allow 3D QK for rotary forward + q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) + k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + output, _ = self.dense(attn_output) + + return output + + +class Phi3SmallDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3SmallSelfAttention(config, + layer_idx, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = Phi3SmallMLP(config, quant_config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi3SmallModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + self.layers = nn.ModuleList([ + Phi3SmallDecoderLayer(config, layer_idx, cache_config, + quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata = None, + ): + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Phi3SmallForCausalLM(nn.Module): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Phi3SmallModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.mup_width_multiplier = config.mup_width_multiplier + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + # tokens in tiktoken but not used + if hasattr(config, 'dummy_token_indices'): + device = self.lm_head.weight.device + self.register_buffer('dummy_token_indices', + torch.LongTensor( + config.dummy_token_indices).to(device), + persistent=False) + else: + self.dummy_token_indices = None + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + if self.dummy_token_indices is not None and logits is not None: + logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + return logits + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + output_hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + output_hidden_states = output_hidden_states + return output_hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits / self.mup_width_multiplier, + sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f36d84dbdf7f9..044eec6410a54 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -63,4 +63,4 @@ def get_hf_text_config(config: PretrainedConfig): assert hasattr(config.text_config, "num_attention_heads") return config.text_config else: - return config + return config \ No newline at end of file From 325c119961698c27d8d11d61d019a6d57c814c51 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 May 2024 23:49:49 -0700 Subject: [PATCH 112/124] [Misc] add logging level env var (#5045) --- .github/ISSUE_TEMPLATE/400-bug report.yml | 2 ++ vllm/envs.py | 5 +++++ vllm/logger.py | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index 08120ad8e5a60..ce980c3f4a01d 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -59,6 +59,8 @@ body: Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/vllm/envs.py b/vllm/envs.py index 56ff79e0cdea9..bef343d08429c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -22,6 +22,7 @@ VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None @@ -178,6 +179,10 @@ "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"), + # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging diff --git a/vllm/logger.py b/vllm/logger.py index 153cdfb373bb4..3c6bf0803a624 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -14,6 +14,7 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" @@ -30,7 +31,7 @@ "vllm": { "class": "logging.StreamHandler", "formatter": "vllm", - "level": "INFO", + "level": VLLM_LOGGING_LEVEL, "stream": "ext://sys.stdout", }, }, From d5a16977729928a2ceafb5ec8764081f40f7cdff Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Sat, 25 May 2024 10:00:14 -0700 Subject: [PATCH 113/124] [Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000) --- .../spec_decode/e2e/test_ngram_correctness.py | 41 +++++++++++++++++++ tests/spec_decode/test_dynamic_spec_decode.py | 16 +++++--- vllm/spec_decode/spec_decode_worker.py | 17 +++++--- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index c2004ff061a1e..d475d37af6425 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 948a74b22f0ae..48fa862b2e41a 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch @@ -13,9 +13,9 @@ from .utils import create_batch, mock_worker -@pytest.mark.parametrize('queue_size', [2, 4]) -@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) -@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@pytest.mark.parametrize('queue_size', [4]) +@pytest.mark.parametrize('batch_size', [1]) +@pytest.mark.parametrize('k', [1]) @torch.inference_mode() def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): """Verify that speculative tokens are disabled when the batch size @@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): num_lookahead_slots=k, running_queue_size=queue_size) - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) + if queue_size > disable_by_batch_size: + with patch.object(worker, + '_run_no_spec', + side_effect=ValueError(exception_secret)), \ + pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) # When the batch size is larger than the threshold, # we expect no speculative tokens (0). diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3462a876c3e90..150e8db0c8aad 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -273,10 +273,17 @@ def execute_model( self._maybe_disable_speculative_tokens( disable_all_speculation, execute_model_req.seq_group_metadata_list) - # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) @@ -316,8 +323,8 @@ def _maybe_disable_speculative_tokens( @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, skip_proposer: bool) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to - the proposer and scorer model so that the KV cache is consistent + """Run a single generation step without any speculation. The input is + sent to the proposer and scorer model so that the KV cache is consistent between the two. When skip_proposer is True, the proposer model is not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. From f17a1a8f9665bb237a3dddda7dc93f259e5e81e0 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 25 May 2024 10:28:16 -0700 Subject: [PATCH 114/124] [Misc] Make Serving Benchmark More User-friendly (#5044) --- benchmarks/backend_request_func.py | 6 ++++++ benchmarks/benchmark_serving.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index f9d167590fe47..58dcc6167efa6 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -89,6 +89,9 @@ async def async_request_tgi( output.latency = most_recent_timestamp - st output.success = True output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -276,6 +279,9 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9c3fed4817de2..f3d71de775f82 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -215,6 +215,11 @@ def calculate_metrics( else: actual_output_lens.append(0) + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -226,9 +231,9 @@ def calculate_metrics( 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, - mean_tpot_ms=np.mean(tpots) * 1000, - median_tpot_ms=np.median(tpots) * 1000, - p99_tpot_ms=np.percentile(tpots, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, ) return metrics, actual_output_lens @@ -250,6 +255,24 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) From 1102bef2195a102c6a5489a28329e543a600b4d8 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 May 2024 15:18:17 -0700 Subject: [PATCH 115/124] [Bugfix / Core] Prefix Caching Guards (merged with main) (#4846) Co-authored-by: rsnm2 Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- .../test_disable_sliding_window.py | 44 ++++++++++++ tests/test_config.py | 24 +++++++ vllm/attention/layer.py | 3 +- vllm/config.py | 67 ++++++++++++++++++- vllm/engine/arg_utils.py | 12 +++- vllm/model_executor/models/llama.py | 4 -- vllm/model_executor/models/mixtral.py | 22 +++--- vllm/model_executor/models/mixtral_quant.py | 4 -- vllm/model_executor/models/qwen2.py | 25 ++++--- vllm/model_executor/models/starcoder2.py | 2 - vllm/model_executor/models/xverse.py | 4 -- 11 files changed, 167 insertions(+), 44 deletions(-) create mode 100644 tests/prefix_caching/test_disable_sliding_window.py diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py new file mode 100644 index 0000000000000..eeac6ab43c05f --- /dev/null +++ b/tests/prefix_caching/test_disable_sliding_window.py @@ -0,0 +1,44 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_prefix_caching.py`. +""" +import pytest + +from tests.conftest import cleanup +from vllm import LLM + +MODEL_LEN_LEN = [ + # Example models with sliding window. + ("bigcode/starcoder2-3b", 4096, 16384), + # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI + + # Confirm model with sliding window works. + # config has "use_sliding_window": false + ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768), + # config has no sliding window attribute. + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048), +] + + +@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) +def test_disable_sliding_window(model_len_len, ): + model, sliding_len, full_len = model_len_len + vllm_disabled_model = LLM(model, disable_sliding_window=True) + vllm_disabled_model.generate("Hi my name is") + model_config = vllm_disabled_model.llm_engine.model_config + assert model_config.max_model_len == sliding_len, ( + "Max len expected to equal sliding_len of %s, but got %s", sliding_len, + model_config.max_model_len) + + del vllm_disabled_model + cleanup() + + vllm_enabled_model = LLM(model, disable_sliding_window=False) + vllm_enabled_model.generate("Hi my name is") + model_config = vllm_enabled_model.llm_engine.model_config + assert model_config.max_model_len == full_len, ( + "Max len expected to equal full_len of %s, but got %s", full_len, + model_config.max_model_len) + + del vllm_enabled_model + cleanup() diff --git a/tests/test_config.py b/tests/test_config.py index 6bc51a53dc07c..7cbdaeca9c4d4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,29 @@ +import pytest + from vllm.config import ModelConfig +MODEL_IDS_EXPECTED = [ + ("Qwen/Qwen1.5-7B", 32768), + ("mistralai/Mistral-7B-v0.1", 4096), + ("mistralai/Mistral-7B-Instruct-v0.2", 32768), +] + + +@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED) +def test_disable_sliding_window(model_id_expected): + model_id, expected = model_id_expected + model_config = ModelConfig( + model_id, + model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + disable_sliding_window=True, + ) + assert model_config.max_model_len == expected + def test_get_sliding_window(): TEST_SLIDING_WINDOW = 4096 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b67f04c51d493..db55a31476fed 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -30,7 +30,6 @@ def __init__( scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -39,9 +38,11 @@ def __init__( if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size + sliding_window = cache_config.sliding_window else: kv_cache_dtype = "auto" block_size = 16 + sliding_window = None if num_kv_heads is None: num_kv_heads = num_heads diff --git a/vllm/config.py b/vllm/config.py index b245a1a3ee6d3..4b256d00a32df 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -69,6 +69,10 @@ class ModelConfig: max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, @@ -96,6 +100,7 @@ def __init__( max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, + disable_sliding_window: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, ) -> None: @@ -118,14 +123,18 @@ def __init__( self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture) self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_text_config, - max_model_len) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) self.served_model_name = get_served_model_name(model, served_model_name) if not self.skip_tokenizer_init: @@ -220,7 +229,7 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") - def get_sliding_window(self) -> Optional[int]: + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled. """ @@ -232,6 +241,15 @@ def get_sliding_window(self) -> Optional[int]: return None return getattr(self.hf_text_config, "sliding_window", None) + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + def get_vocab_size(self) -> int: return self.hf_text_config.vocab_size @@ -336,6 +354,7 @@ def __init__( self.enable_prefix_caching = enable_prefix_caching self._verify_args() self._verify_cache_dtype() + self._verify_prefix_caching() # Will be set after profiling. self.num_gpu_blocks = None @@ -364,6 +383,19 @@ def _verify_cache_dtype(self) -> None: else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -1116,6 +1148,8 @@ def _get_and_verify_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -1135,6 +1169,7 @@ def _get_and_verify_max_len( "max_seq_length", "seq_len", ] + # Choose the smallest "max_length" from the possible keys. max_len_key = None for key in possible_keys: max_len = getattr(hf_config, key, None) @@ -1142,6 +1177,16 @@ def _get_and_verify_max_len( max_len_key = key if max_len < derived_max_model_len \ else max_len_key derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. if derived_max_model_len == float("inf"): if max_model_len is not None: # If max_model_len is specified, we use it. @@ -1157,6 +1202,13 @@ def _get_and_verify_max_len( rope_scaling = getattr(hf_config, "rope_scaling", None) if rope_scaling is not None and rope_scaling["type"] != "su": + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": @@ -1164,6 +1216,8 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. if max_model_len is None: max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: @@ -1172,6 +1226,13 @@ def _get_and_verify_max_len( # with model_max_length and allow this override when it's smaller. model_max_length = getattr(hf_config, "model_max_length", None) if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") pass else: raise ValueError( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 538e3427e37fb..3267c8c9f44d2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -41,6 +41,7 @@ class EngineArgs: max_parallel_loading_workers: Optional[int] = None block_size: int = 16 enable_prefix_caching: bool = False + disable_sliding_window: bool = False use_v2_block_manager: bool = False swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 @@ -267,6 +268,10 @@ def add_cli_args( parser.add_argument('--enable-prefix-caching', action='store_true', help='Enables automatic prefix caching.') + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size') parser.add_argument('--use-v2-block-manager', action='store_true', help='Use BlockSpaceMangerV2.') @@ -558,8 +563,8 @@ def create_engine_config(self, ) -> EngineConfig: self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, self.max_seq_len_to_capture, - self.max_logprobs, self.skip_tokenizer_init, - self.served_model_name) + self.max_logprobs, self.disable_sliding_window, + self.skip_tokenizer_init, self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -645,7 +650,8 @@ def create_engine_config(self, ) -> EngineConfig: if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled): raise ValueError( - "Chunked prefill is not supported with sliding window.") + "Chunked prefill is not supported with sliding window. " + "Set --disable-sliding-window to disable sliding window.") return EngineConfig(model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 086f9294c4f1c..2ca55f9270fc7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -94,7 +94,6 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -146,7 +145,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -183,7 +181,6 @@ def __init__( config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( @@ -198,7 +195,6 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, - sliding_window=sliding_window, cache_config=cache_config, ) self.mlp = LlamaMLP( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ea95cf7380d54..d6dd7fa1fe9e2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -246,15 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -276,7 +277,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window if isinstance( quant_config, @@ -312,7 +312,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -349,7 +348,6 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 9b99ff729aadd..1894c05e167d6 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -166,7 +166,6 @@ def __init__( max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -190,7 +189,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -217,7 +215,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -254,7 +251,6 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index ec203c3b9001a..9a4829a27873e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -86,10 +86,8 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - use_sliding_window: bool = False, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, rope_scaling: Optional[Tuple] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -112,7 +110,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window if use_sliding_window else None self.qkv_proj = QKVParallelLinear( hidden_size, @@ -140,7 +137,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, config: Qwen2Config, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -173,18 +168,14 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - use_sliding_window = (config.use_sliding_window - and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - use_sliding_window=use_sliding_window, cache_config=cache_config, quant_config=quant_config, - sliding_window=config.sliding_window, rope_scaling=rope_scaling) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -244,8 +235,8 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) + Qwen2DecoderLayer(config, cache_config, quant_config) + for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -302,6 +293,18 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + super().__init__() self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 91ffd0861c39d..4324bf50d4ad1 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -74,7 +74,6 @@ def __init__(self, self.rope_theta = config.rope_theta self.max_position_embeddings = config.max_position_embeddings self.use_bias = config.use_bias - self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( self.hidden_size, @@ -101,7 +100,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index dda13d83f89a3..1e5280dde3ff9 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -88,7 +88,6 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -134,7 +133,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -167,7 +165,6 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) self.self_attn = XverseAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -178,7 +175,6 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=getattr(config, "bias", False), - sliding_window=sliding_window, cache_config=cache_config, ) self.mlp = XverseMLP( From fbdb7b3ee2c3f7b58841c12b52ed4c83f508babb Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Mon, 27 May 2024 22:26:14 +0000 Subject: [PATCH 116/124] [Core] Allow AQLM on Pascal (#5058) --- vllm/model_executor/layers/quantization/aqlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 83e24fadc1405..730595c3d36d1 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -192,7 +192,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 70 + return 60 @classmethod def get_config_filenames(cls) -> List[str]: From 890aa93d275a2b75313629614ab9ed278a13f6d7 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 28 May 2024 07:41:43 +0800 Subject: [PATCH 117/124] [Model] Add support for falcon-11B (#5069) --- vllm/model_executor/models/falcon.py | 55 ++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ba707adb03dfe..9618652f70d23 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -246,18 +246,26 @@ def __init__( self.mlp = FalconMLP(config, quant_config) self.config = config - if config.new_decoder_architecture: - # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) - # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - else: + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if not config.parallel_attn: - self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -271,7 +279,7 @@ def forward( ) -> torch.Tensor: residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -294,6 +302,10 @@ def forward( residual += attention_output mlp_layernorm_out = self.post_attention_layernorm(residual) + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + # MLP. mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) if self.reduce_row_parallel_results and mlp_bias is not None: @@ -375,7 +387,20 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = FalconModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.word_embeddings.weight + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head_weight = self.transformer.word_embeddings.weight + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -419,8 +444,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if name == "lm_head.weight": - # Falcon uses tied embeddings. + if name == "lm_head.weight" and self.tie_word_embeddings: + # Falcon uses tied embeddings except Falcon-11b. continue # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: From d4f398590786f0015d474b03a3d078db1e7d1be2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Mon, 27 May 2024 19:07:07 -0700 Subject: [PATCH 118/124] [Core] Sliding window for block manager v2 (#4545) Co-authored-by: Ruth Evans --- tests/core/block/e2e/conftest.py | 26 +++ tests/core/block/e2e/test_correctness.py | 11 +- .../e2e/test_correctness_sliding_window.py | 168 ++++++++++++++++++ tests/core/block/test_block_manager_v2.py | 69 +++++++ vllm/attention/ops/prefix_prefill.py | 6 +- vllm/core/block/block_table.py | 34 +++- vllm/core/block/cpu_gpu_block_allocator.py | 74 ++++++++ vllm/core/block/interfaces.py | 9 + vllm/core/block_manager_v2.py | 24 ++- vllm/engine/arg_utils.py | 3 +- vllm/worker/cache_engine.py | 5 +- vllm/worker/model_runner.py | 73 +++++--- 12 files changed, 457 insertions(+), 45 deletions(-) create mode 100644 tests/core/block/e2e/test_correctness_sliding_window.py diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index b0d62c8993d3f..e870597b7a011 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,3 +1,5 @@ +from typing import Callable, Iterable, Optional + import pytest from vllm import LLM @@ -40,3 +42,27 @@ def generator_inner(): for llm in generator_inner(): yield llm del llm + + +def get_text_from_llm_generator(llm_generator: Iterable[LLM], + prompts, + sampling_params, + llm_cb: Optional[Callable[[LLM], + None]] = None): + for llm in llm_generator: + if llm_cb: + llm_cb(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + text = [output.outputs[0].text for output in outputs] + del llm + + return text + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index c3666da7542b5..3713ef2fed4d1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -4,6 +4,8 @@ from vllm import SamplingParams +from .conftest import get_token_ids_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..e98292e807d73 --- /dev/null +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -0,0 +1,168 @@ +import random +from typing import List + +import pytest + +from vllm import LLM, SamplingParams + +from .conftest import get_text_from_llm_generator + +# relatively small model with 4k sliding window +MODEL = "bigcode/starcoder2-3b" +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, + batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + + Additionally, we compare the results of the v1 and v2 managers. + """ + sampling_params = SamplingParams( + max_tokens=1024, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + print('Getting token ids from block manager v1') + baseline_texts = get_text_from_llm_generator(baseline_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + + check_answers(indices, answer, baseline_texts) + + print('Getting token ids from block manager v2') + test_texts = get_text_from_llm_generator(test_llm_generator, prompts, + sampling_params) + check_answers(indices, answer, test_texts) + + cmp = [ + expected_text == actual_text + for expected_text, actual_text in zip(baseline_texts, test_texts) + ] + print(cmp) + # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 + # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 + # states that xformers and flash_attn have different ideas about the window + # size anyways + assert sum(cmp) > 0.7 * len(cmp) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "enable_chunked_prefill": True +}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): + """ + This is similar to test_sliding_window_retrival, however, it doesn't + compare against the v1 block manager since v1 doesn't support + chunked prefill with sliding window. + + The results with and without chunked prefill are not the same due to + numerical instabilities. + """ + sampling_params = SamplingParams( + max_tokens=10, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + # We don't compare with the baseline model here, since the results + # slightly different due to different tailing in attention. + test_texts = get_text_from_llm_generator(test_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + check_answers(indices, answer, test_texts) + + +def prep_prompts(batch_size: int): + """ + Generate prompts which a bunch of assignments, + then asking for the value of one of them. + The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct. + """ + prompts: List[str] = [] + answer: List[int] = [] + indices: List[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = "```python\n# We set a number of variables, " + \ + f"x{idx} will be important later\n" + ln = random.randint(800, 1100) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers(indices: List[int], answer: List[int], outputs: List[str]): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok > 0.7 + + +def check_window(prompts: List[str]): + + def inner(llm: LLM): + sliding_window = llm.llm_engine.model_config.get_sliding_window() + assert sliding_window and sliding_window > 0 + assert any( + len(llm.get_tokenizer().tokenize(prompt)) > sliding_window + for prompt in prompts) + + return inner diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..91b047f0e183e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, range(prompt_len + num_slots_to_append + num_lookahead_slots)), block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) assert num_consumed_blocks == expected_consumed_blocks + + +@pytest.mark.parametrize("block_size", [8, 16]) +@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) +@pytest.mark.parametrize("num_slots_to_append", [50]) +@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) +def test_sliding_window(block_size, prompt_len, num_slots_to_append, + sliding_window): + """Verify append_slots consumes the correct number of blocks from the block + table. + """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + sliding_window=sliding_window, + ) + + def check_used(min_n, max_n=None): + if max_n is None: + max_n = min_n + used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() + #print("check", min_n, used, max_n) + assert min_n <= used + assert used <= max_n + + def num_blocks(num_tokens): + return (num_tokens + block_size - 1) // block_size + + check_used(0) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + check_used(0) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + check_used(num_blocks(prompt_len)) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + seq.data.update_num_computed_tokens(prompt_len) + check_used(num_blocks(prompt_len)) + + # this is how we compute it in BlockSpaceManagerV2.__init__ + sliding_blocks = (sliding_window // block_size) + 2 + # plus one block for null block + sliding_blocks += 1 + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + seq.data.update_num_computed_tokens(1) + block_manager.append_slots(seq, num_lookahead_slots=0) + if prompt_len < sliding_window + 10: + check_used(0, sliding_blocks + 1) + else: + check_used(sliding_blocks, sliding_blocks + 1) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 997b25e887e30..b99cf9a50d105 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -697,6 +697,10 @@ def context_attention_fwd(q, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: _fwd_kernel_alibi[grid]( @@ -794,7 +798,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window if sliding_window is not None else 0, + SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index b0d9511fba521..26c704b8de901 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -20,6 +20,10 @@ class BlockTable: _blocks (Optional[List[Block]], optional): An optional list of existing blocks to initialize the BlockTable with. If not provided, an empty BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequance. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. Attributes: _block_size (int): The maximum number of tokens that can be stored in a @@ -37,6 +41,7 @@ def __init__( block_size: int, block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, ): self._block_size = block_size self._allocator = block_allocator @@ -44,6 +49,7 @@ def __init__( _blocks = [] self._blocks: List[Block] = _blocks + self._max_block_sliding_window = max_block_sliding_window # Use helper method instead of directly calculating, as blocks # may not be allocated. self._num_full_slots = len(self._get_all_token_ids()) @@ -89,7 +95,8 @@ def allocate(self, def append_token_ids(self, token_ids: List[int], - num_lookahead_slots: int = 0) -> None: + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -104,13 +111,35 @@ def append_token_ids(self, Args: token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. """ - assert self._is_allocated + assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) + # Update the blocks with the new tokens blocks = self._blocks[self._num_full_slots // self._block_size:] token_blocks = self._chunk_token_blocks_for_append(token_ids) @@ -168,6 +197,7 @@ def fork(self) -> "BlockTable": block_size=self._block_size, block_allocator=self._allocator, _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, ) def free(self) -> None: diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 0577ca76ea971..d28a684376974 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -105,11 +105,19 @@ def __init__( Device.GPU: gpu_block_allocator, } + self._null_block: Optional[Block] = None + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable(None, Device.GPU)) + return self._null_block + def allocate_mutable(self, prev_block: Optional[Block], device: Device) -> Block: """Allocates a new mutable block on the specified device. @@ -149,6 +157,9 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. """ + # Null block should never be freed + if isinstance(block, NullBlock): + return block_id = block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -165,6 +176,8 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) block_id = last_block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -226,3 +239,64 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: raise NotImplementedError + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 140fbbb0949cc..8fc4c601106cd 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -203,3 +203,12 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f0bc96564050a..834436c25e160 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -66,9 +66,18 @@ def __init__( self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks - assert sliding_window is None, "Sliding window not yet supported" - - self.block_sliding_window = None + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 self.watermark = watermark assert watermark >= 0.0 @@ -96,10 +105,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - assert self.block_sliding_window is None - if self.block_sliding_window is not None: + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) @@ -125,8 +133,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table = BlockTable( block_size=self.block_size, block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, ) - assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) self.block_tables[seq.seq_id] = block_table @@ -174,6 +183,7 @@ def append_slots( block_table.append_token_ids( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), ) # Return any new copy-on-writes. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3267c8c9f44d2..11485aa2438c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -648,7 +648,8 @@ def create_engine_config(self, ) -> EngineConfig: guided_decoding_backend=self.guided_decoding_backend) if (model_config.get_sliding_window() is not None - and scheduler_config.chunked_prefill_enabled): + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager): raise ValueError( "Chunked prefill is not supported with sliding window. " "Set --disable-sliding-window to disable sliding window.") diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 07d51dca226bd..2f0e59f7ae7c9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -68,8 +68,11 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. kv_cache.append( - torch.empty(kv_cache_shape, + torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 87d5f5c1b9d67..5ddd2d1b65f81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -269,6 +269,12 @@ def _prepare_model_input( if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window + self.block_size - + 1) // self.block_size + block_aligned_sliding_window = \ + sliding_window_blocks * self.block_size + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt @@ -309,6 +315,30 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. @@ -316,6 +346,13 @@ def _prepare_model_input( assert computed_block_nums is not None context_len = len(computed_block_nums) * self.block_size tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + if self.attn_backend.get_name() == "flash-attn": # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. @@ -329,14 +366,9 @@ def _prepare_model_input( if seq_group_metadata.block_tables is not None: # chunked prefill or decode block_table = seq_group_metadata.block_tables[seq_id] - if self.sliding_window is not None: - # chunked prefill doesn't support sliding window. - assert (not self.scheduler_config. - chunked_prefill_enabled) - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + @@ -354,16 +386,9 @@ def _prepare_model_input( block_table = [] block_tables.append(block_table) - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - seq_len = min(seq_len, self.sliding_window) - context_len = seq_len - 1 - - seq_lens.append(seq_len) - context_lens.append(context_len) - query_len = seq_len - context_len + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len query_lens.append(query_len) input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) @@ -380,16 +405,15 @@ def _prepare_model_input( "seq_len: {}, context_len: {}, query_len: {}".format( seq_len, context_len, query_len)) num_decode_tokens += query_len - decode_seq_lens.append(seq_len) + decode_seq_lens.append(sliding_seq_len) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * - (seq_len - - context_len if seq_group_metadata.sampling_params + (query_len if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs else 1)) @@ -417,9 +441,10 @@ def _prepare_model_input( start_idx = 0 if self.sliding_window is not None: if is_prompt: - assert context_len == 0, ( + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( "Prefix caching is currently not supported with " - "sliding window attention") + "sliding window attention in V1 block manager") # It is an optimization. When it is decoding, it is always # 0. When prefill, we use it to not write slots to kv cache # to save memory. From 9ba415588aeda8d99bda8889f90010f0d7330e89 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 28 May 2024 08:32:42 -0700 Subject: [PATCH 119/124] [BugFix] Fix Embedding Models with TP>1 (#5075) --- vllm/worker/embedding_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ef02de95fc54e..0ba1200696cab 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -79,6 +79,10 @@ def execute_model( execute_model_kwargs.update({"image_input": multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return None + return self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata) From dd8de11f0a15f9ef48cd1dcac02dd8e2a8bcb494 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Tue, 28 May 2024 11:03:23 -0500 Subject: [PATCH 120/124] [Kernel][ROCm][AMD] Add fused_moe Triton configs for MI300X (#4951) This PR adds Triton kernel configs for the MoE kernel for MI300X --- ...14336,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ ...=1792,device_name=AMD_Instinct_MI300X.json | 110 +++++++++++++++ ...=3584,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ ...=7168,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ 4 files changed, 494 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..93472eb08a462 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..5bd9d71e8f9bb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..02e66280c1a3a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_stages": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..34c3b593d9799 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_stages": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} From 290f4ada2bf42174a53ae6aab2873e115c8ae11b Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 28 May 2024 12:29:09 -0500 Subject: [PATCH 121/124] [Docs] Add Dropbox as sponsors (#5089) --- README.md | 1 + docs/source/community/sponsors.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 627e45d4de0c4..d63819c3815c0 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ vLLM is a community project. Our compute resources for development and testing a - Crusoe Cloud - Databricks - DeepInfra +- Dropbox - Lambda Lab - NVIDIA - Replicate diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md index 532ce77beb7b8..d167b66267a4d 100644 --- a/docs/source/community/sponsors.md +++ b/docs/source/community/sponsors.md @@ -12,6 +12,7 @@ vLLM is a community project. Our compute resources for development and testing a - Crusoe Cloud - Databricks - DeepInfra +- Dropbox - Lambda Lab - NVIDIA - Replicate From 5ae5ed1e6047d4095149e26526a618be0529a118 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 29 May 2024 04:29:31 +0800 Subject: [PATCH 122/124] [Core] Consolidate prompt arguments to LLM engines (#4328) Co-authored-by: Roger Wang --- .buildkite/test-pipeline.yaml | 9 +- benchmarks/benchmark_latency.py | 11 +- .../{ => dev}/offline_inference/llm.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 14 + .../dev/offline_inference/offline_index.rst | 8 + .../sampling_params.rst | 0 docs/source/index.rst | 11 +- .../serving/openai_compatible_server.md | 4 +- examples/llava_example.py | 25 +- pyproject.toml | 7 + tests/async_engine/test_async_llm_engine.py | 2 +- tests/async_engine/test_openapi_server_ray.py | 2 +- tests/conftest.py | 23 +- tests/core/test_block_manager.py | 15 +- tests/core/utils.py | 15 +- tests/engine/test_skip_tokenizer_init.py | 2 +- tests/entrypoints/openai/test_serving_chat.py | 4 + tests/entrypoints/test_guided_processors.py | 2 + tests/entrypoints/test_llm_encode.py | 144 +++++++ tests/entrypoints/test_llm_generate.py | 137 +++++- tests/entrypoints/test_openai_server.py | 34 +- .../test_server_oot_registration.py | 11 +- tests/lora/test_long_context.py | 8 +- tests/samplers/test_logits_processor.py | 11 +- tests/samplers/test_seeded_generate.py | 6 +- tests/test_cache_block_hashing.py | 11 +- tests/test_inputs.py | 53 +++ tests/test_utils.py | 63 +++ tests/tokenization/test_detokenize.py | 7 +- tests/utils.py | 14 + vllm/__init__.py | 4 + vllm/engine/async_llm_engine.py | 171 ++++---- vllm/engine/llm_engine.py | 269 ++++++++---- vllm/engine/output_processor/util.py | 10 +- vllm/entrypoints/llm.py | 404 +++++++++++++----- vllm/entrypoints/openai/serving_chat.py | 12 +- vllm/entrypoints/openai/serving_completion.py | 17 +- vllm/entrypoints/openai/serving_embedding.py | 42 +- vllm/entrypoints/openai/serving_engine.py | 6 +- vllm/inputs.py | 130 ++++++ vllm/outputs.py | 42 +- vllm/sequence.py | 38 +- vllm/utils.py | 43 +- 43 files changed, 1404 insertions(+), 439 deletions(-) rename docs/source/{ => dev}/offline_inference/llm.rst (86%) create mode 100644 docs/source/dev/offline_inference/llm_inputs.rst create mode 100644 docs/source/dev/offline_inference/offline_index.rst rename docs/source/{offline_inference => dev}/sampling_params.rst (100%) create mode 100644 tests/entrypoints/test_llm_encode.py create mode 100644 tests/test_inputs.py create mode 100644 tests/test_utils.py create mode 100644 vllm/inputs.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index def8a460e84a7..08e132d0c68bf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -63,9 +63,9 @@ steps: mirror_hardwares: [amd] commands: - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s test_inputs.py + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" @@ -110,6 +110,9 @@ steps: mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a9657f7859750..3146fb33cc27e 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -3,13 +3,14 @@ import json import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -48,7 +49,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst similarity index 86% rename from docs/source/offline_inference/llm.rst rename to docs/source/dev/offline_inference/llm.rst index 1a443ea406994..83ba1b6987c6d 100644 --- a/docs/source/offline_inference/llm.rst +++ b/docs/source/dev/offline_inference/llm.rst @@ -1,5 +1,5 @@ LLM Class -========== +========= .. autoclass:: vllm.LLM :members: diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst new file mode 100644 index 0000000000000..31c3d16a3c8eb --- /dev/null +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 0000000000000..27dfb0e9df90e --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/offline_inference/sampling_params.rst b/docs/source/dev/sampling_params.rst similarity index 100% rename from docs/source/offline_inference/sampling_params.rst rename to docs/source/dev/sampling_params.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 5db1c9346c45d..5f18fe9ae0a73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -68,13 +68,6 @@ Documentation getting_started/quickstart getting_started/examples/examples_index -.. toctree:: - :maxdepth: 1 - :caption: Offline Inference - - offline_inference/llm - offline_inference/sampling_params - .. toctree:: :maxdepth: 1 :caption: Serving @@ -108,7 +101,9 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a775c6addf1d9..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654bf..60250c4303fbf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,11 +23,15 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) for o in outputs: generated_text = o.outputs[0].text print(generated_text) @@ -46,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_image_features.pt") - - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + image = torch.load("images/stop_sign_image_features.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/pyproject.toml b/pyproject.toml index 96f78c37cfefb..0e9096fb4c035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup", + "llm: run tests for vLLM API only", + "openai: run tests for OpenAI API only", +] diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index b69cdc0a21409..10a46422887e3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def generate(self, request_id): diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index ace4c53916c71..7a8d4b3915617 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -29,7 +29,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/conftest.py b/tests/conftest.py index c1a44a606e1bf..af04cfbbb9902 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.sequence import MultiModalData @@ -402,12 +403,22 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) + + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 22a9f0cf47d32..88cd4f98091f9 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -133,8 +133,11 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + "multi_modal_data": None + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -304,7 +307,13 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + "multi_modal_data": None + }, + block_size=block_size) seq_group = SequenceGroup(request_id="1", seqs=[parent], arrival_time=time.time(), diff --git a/tests/core/utils.py b/tests/core/utils.py index 8fb13177a2d6c..1c5724090b69b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -21,7 +21,13 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], arrival_time=time.time(), @@ -51,8 +57,11 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index baa463a316902..338b208723ba9 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError) as err: llm.generate("abc", sampling_params) assert "prompts must be None if" in str(err.value) - outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 completions = outputs[0].outputs diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b5..c45f02fe564a3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,11 +1,15 @@ import asyncio from dataclasses import dataclass +import pytest + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +pytestmark = pytest.mark.openai + @dataclass class MockModelConfig: diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 41c871ca40bc8..5d4163e96fd87 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -52,6 +52,8 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 0000000000000..7c3fbe43a8384 --- /dev/null +++ b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,144 @@ +import weakref +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + +pytestmark = pytest.mark.llm + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 5e8b7ca4d9977..a00fff91a310e 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,21 +1,124 @@ +import weakref +from typing import List + import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] -def test_multiple_sampling_params(): +pytestmark = pytest.mark.llm - llm = LLM(model="facebook/opt-125m", + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=4096, - tensor_parallel_size=1) + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95), @@ -24,18 +127,18 @@ def test_multiple_sampling_params(): ] # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) # Exception raised, if the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) # Single SamplingParams should be applied to every prompt single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) - outputs = llm.generate(prompts, sampling_params=single_sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) # sampling_params is None, default params should be applied - outputs = llm.generate(prompts, sampling_params=None) - assert len(prompts) == len(outputs) \ No newline at end of file + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1b04e3205c4b8..2463ccde2bc8b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,7 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -91,6 +91,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files): # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", "--max-model-len", "8192", - "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner @@ -136,6 +140,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is None +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_completion(server, client: openai.AsyncOpenAI, @@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_chat(server, client: openai.AsyncOpenAI, @@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, @@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, @@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, assert ip1 != ip2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, @@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, assert completion.choices[i].text in TEST_CHOICE +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, @@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, assert choice1 != choice2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, @@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, @@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token, logprob in token_dict.items()) +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( @@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio async def test_extra_fields(server, client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( @@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +@pytest.mark.asyncio async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, @@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +@pytest.mark.asyncio async def test_custom_role(server, client: openai.AsyncOpenAI): # Not sure how the model handles custom roles so we just check that # both string and complex message content are handled in the same way @@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI): assert content1 == content2 +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -847,6 +873,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +@pytest.mark.asyncio async def test_long_seed(server, client: openai.AsyncOpenAI): for seed in [ torch.iinfo(torch.long).min - 1, @@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], @@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 5 +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da1..3e55d7f4297fb 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,7 +1,7 @@ -import multiprocessing import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +10,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class MyOPTForCausalLM(OPTForCausalLM): @@ -26,15 +28,16 @@ def server_function(port): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = torch.multiprocessing.get_context() + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..4361e5452cdff 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -86,20 +86,18 @@ def generate( def batched_generate( - llm, + llm: vllm.LLM, inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input - requests_data = llm._validate_and_prepare_requests( + # Add requests to the engine and run the engine + llm._validate_and_add_requests( prompt, sampling_param, lora_request=lora_req, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - llm._add_request(**request_data) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index be4c2ea1b7810..0ccbabfff6403 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -35,28 +35,25 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[0], + example_prompts[0], params=params_with_logprobs, - prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[1], + example_prompts[1], params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), - prompt_token_ids=None, ) # test grouped requests vllm_model.model._add_request( - prompt=example_prompts[2], + example_prompts[2], params=SamplingParams(max_tokens=max_tokens), - prompt_token_ids=None, ) - outputs = vllm_model.model._run_engine(False) + outputs = vllm_model.model._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index ce4501bbf71e5..fef5ff3fb9e8e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - params=params, - ) + llm._add_request(prompt, params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f56..97864af88e40a 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000000000..887c7101decda --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..54dc5c6f5bfba --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,63 @@ +import pytest + +from vllm.utils import deprecate_kwargs + +from .utils import error_on_warning + + +def test_deprecate_kwargs_always(): + + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f1..1d4c74d6bd8da 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8a..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess import sys import time +import warnings +from contextlib import contextmanager import ray import requests @@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ray.get(refs) ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12af..a0e154d24087c 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -16,6 +17,9 @@ __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a15ed67e3327..d4289c715d9e6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,12 +12,13 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -244,64 +245,69 @@ async def step_async( return request_outputs - async def encode_request_async( + async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.process_model_inputs_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + self._add_processed_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return self.add_request(request_id, - prompt=prompt, - params=params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -315,8 +321,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -526,22 +532,26 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " @@ -562,39 +572,33 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await ( - self.engine.encode_request_async.remote( # type: ignore + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request)) + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, + inputs=processed_inputs, params=params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -603,14 +607,12 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `RequestOutput` objects from the LLMEngine @@ -659,24 +661,20 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, - prompt: Optional[str], + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -685,14 +683,12 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `EmbeddingRequestOutput` objects from the LLMEngine @@ -739,24 +735,21 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) - async def process_request( + async def _process_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, + *, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -764,12 +757,10 @@ async def process_request( stream = await self.add_request( request_id, - prompt, + inputs, params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0631c0de76822..08bccf209b7c4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,8 @@ import time -from typing import Iterable, List, Optional, Type, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, TypeVar, Union from transformers import GenerationConfig, PreTrainedTokenizer @@ -18,6 +21,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -25,8 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - MultiModalData, PoolerOutput, SamplerOutput, - Sequence, SequenceGroup, SequenceGroupMetadata, + PoolerOutput, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): return {} +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -60,11 +67,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` + class. For the comprehensive list of arguments, see :ref:`engine_args`. Args: model_config: The configuration related to the LLM model. @@ -81,9 +88,60 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. """ + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + def __init__( self, model_config: ModelConfig, @@ -151,12 +209,11 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer: BaseTokenizerGroup - self._init_tokenizer() + self.tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) else: - self.detokenizer = None self.tokenizer = None + self.detokenizer = None self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( @@ -318,14 +375,26 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group( + self, + fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self.get_tokenizer_group().get_lora_tokenizer( + sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -335,8 +404,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) + + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -346,29 +416,85 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def encode_request( + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def _add_processed_request( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + processed_inputs: LLMInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) + + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def process_model_inputs( + self, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -378,15 +504,14 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. - params: Parameters for sampling or pooling. SamplingParams - for text generation. PoolingParams for pooling. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. @@ -417,59 +542,26 @@ def add_request( "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = None - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) def _create_sequence_group_with_sampling( self, request_id: str, seq: Sequence, sampling_params: SamplingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -495,8 +587,7 @@ def _create_sequence_group_with_sampling( seqs=[seq], arrival_time=arrival_time, sampling_params=sampling_params, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) return seq_group @@ -505,9 +596,8 @@ def _create_sequence_group_with_pooling( request_id: str, seq: Sequence, pooling_params: PoolingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -517,7 +607,6 @@ def _create_sequence_group_with_pooling( seqs=[seq], arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, pooling_params=pooling_params) return seq_group @@ -570,7 +659,7 @@ def _process_sequence_group_outputs( def _process_model_outputs( self, - output: List[Union[SamplerOutput, PoolerOutput]], + output: GenericSequence[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], @@ -585,7 +674,7 @@ def _process_model_outputs( # Organize outputs by [sequence group][step] instead of # [step][sequence group]. output_by_sequence_group = create_output_by_sequence_group( - sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 9816e966c1e36..57cc33d911183 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,18 +1,20 @@ from typing import List +from typing import Sequence as GenericSequence +from typing import Union -from vllm.sequence import SamplerOutput, SequenceGroupOutput +from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput def create_output_by_sequence_group( - sampler_outputs: List[SamplerOutput], + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. """ - output_by_sequence_group: List[List[SamplerOutput]] = [ + output_by_sequence_group: List[List[SequenceGroupOutput]] = [ [] for _ in range(num_seq_groups) ] - for step in sampler_outputs: + for step in outputs: for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 25f4428100b27..9759d05577796 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,11 +1,14 @@ -from typing import List, Optional, Union +from contextlib import contextmanager +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, + parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -13,7 +16,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs logger = init_logger(__name__) @@ -28,8 +31,10 @@ class LLM: mechanism and efficient memory management. NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + + NOTE: For the comprehensive list of arguments, see + :class:`~vllm.EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. @@ -81,6 +86,18 @@ class LLM: disable_custom_all_reduce: See ParallelConfig """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + def __init__( self, model: str, @@ -138,15 +155,101 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # LEGACY: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) def generate( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload + def generate( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -155,49 +258,138 @@ def generate( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=sampling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) - return self._run_engine(use_tqdm) + @overload # LEGACY: single (prompt + optional token ids) + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + @overload # LEGACY: multi (prompt + optional token ids) def encode( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], pooling_params: Optional[Union[PoolingParams, - List[PoolingParams]]] = None, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload + def encode( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def encode( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -206,124 +398,133 @@ def encode( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `EmbeddingRequestOutput` objects containing the generated embeddings in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) - return self._run_engine(use_tqdm) - - def _validate_and_prepare_requests( + # LEGACY + def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], - params: Union[Union[SamplingParams, PoolingParams], - List[Union[SamplingParams, - PoolingParams]]], # Unified parameter - prompt_token_ids: Optional[List[List[int]]] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, - ) -> List[dict]: - """Validates and prepares request data for adding to the engine. + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + multi_modal_data: Optional[MultiModalData], + ): + # skip_tokenizer_init is now checked in engine - Ensures prompts and token IDs are consistent, and returns a list of - dictionaries with request data for further processing. - """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") - if isinstance(prompts, str): - # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + num_requests = None if prompts is not None: num_requests = len(prompts) - else: - assert prompt_token_ids is not None + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] + for i in range(num_requests): + if prompts is not None: + if prompt_token_ids is not None: + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) + else: + item = TextPrompt(prompt=prompts[i]) + else: + if prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + + return inputs + + def _validate_and_add_requests( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[LoRARequest], + ) -> None: + if isinstance(inputs, (str, dict)): + # Convert a single prompt to a list. + inputs = [inputs] + + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - requests_data = [] - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] - - multi_modal_item = MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0), - ) if multi_modal_data else None - - requests_data.append({ - "prompt": - prompt, - "params": - params[i] if isinstance(params, list) else params, - "prompt_token_ids": - token_ids, - "lora_request": - lora_request, - "multi_modal_data": - multi_modal_item, - }) - - return requests_data + for i, request_inputs in enumerate(inputs): + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request, + ) def _add_request( self, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, + inputs, params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) def _run_engine( - self, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -355,5 +556,4 @@ def _run_engine( # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7e179362eef8a..33daabd881df0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -176,9 +176,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 158d8ed7fbbf5..d1812c8f44f41 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -119,12 +119,17 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids, - lora_request=lora_request)) + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a57be0c88915..5a3448de3d7a4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator, List, Tuple +from typing import AsyncIterator, List, Optional, Tuple from fastapi import Request @@ -100,11 +100,16 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - pooling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids)) + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -113,16 +118,21 @@ async def create_embedding(self, request: EmbeddingRequest, int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0df0223b9dbb2..708b0dad102c4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -143,7 +143,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -155,7 +156,8 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 0000000000000..f5d99b1b66b70 --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,130 @@ +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(content=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TextTokensPrompt(TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt text.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" + +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalData"] diff --git a/vllm/outputs.py b/vllm/outputs.py index f9bce9e683f22..49f526b5f9300 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. """ - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None @@ -57,6 +48,7 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass class EmbeddingOutput: """The output data of one completion output of a request. @@ -65,15 +57,11 @@ class EmbeddingOutput: length of vector depends on the model as listed in the embedding guide. """ - def __init__( - self, - embedding: List[float], - ) -> None: - self.embedding = embedding + embedding: List[float] def __repr__(self) -> str: return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)}") + f"embedding={len(self.embedding)})") class RequestOutput: @@ -93,7 +81,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], @@ -183,7 +171,7 @@ class EmbeddingRequestOutput: finished (bool): A flag indicating whether the embedding is completed. """ - def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + def __init__(self, request_id: str, outputs: "EmbeddingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids diff --git a/vllm/sequence.py b/vllm/sequence.py index aa759448d82b1..f8e9da6c7965a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -210,8 +211,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -220,25 +220,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data: SequenceData = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -248,6 +247,18 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs["prompt"] + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs["multi_modal_data"] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -415,7 +426,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. embeddings: The embeddings vectors of the prompt of the sequence group for an embedding model. pooling_params: The pooling parameters used to generate the pooling @@ -429,7 +439,6 @@ def __init__( arrival_time: float, sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, ) -> None: @@ -444,12 +453,11 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data self.embeddings = embeddings self.pooling_params = pooling_params @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -458,7 +466,13 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: diff --git a/vllm/utils.py b/vllm/utils.py index 4cb9d905097bf..c8bc54dab41b3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import uuid import warnings from collections import defaultdict -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None: filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) + + +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper From dfba529b4024fba9ce1346467318b35e8f2fa9d9 Mon Sep 17 00:00:00 2001 From: Junichi Sato Date: Wed, 29 May 2024 09:15:35 +0900 Subject: [PATCH 123/124] [Bugfix] Remove the last EOS token unless explicitly specified (#5077) --- .../output_processor/test_stop_checker.py | 86 +++++++++++++++++++ vllm/engine/output_processor/stop_checker.py | 5 ++ 2 files changed, 91 insertions(+) create mode 100644 tests/engine/output_processor/test_stop_checker.py diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py new file mode 100644 index 0000000000000..ae54c83605e11 --- /dev/null +++ b/tests/engine/output_processor/test_stop_checker.py @@ -0,0 +1,86 @@ +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, Sequence, SequenceStatus + + +def sequence_with_eos(text: str, eos_token: str, + eos_token_id: int) -> Sequence: + """ + Create a Sequence that ends with an EOS token. + """ + seq = Sequence( + seq_id=0, + prompt="", + prompt_token_ids=[], + block_size=16, + eos_token_id=eos_token_id, + ) + seq.output_text = text + eos_token + + offset = eos_token_id + 1 + for i in range(offset, len(text) + offset): + seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) + seq.append_token_id(token_id=eos_token_id, + logprobs={eos_token_id: Logprob(0.0)}) + + seq.status = SequenceStatus.RUNNING + + return seq + + +@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ + ("This text ends with EOS token", "", 2), +]) +@pytest.mark.parametrize("ignore_eos", [True, False, None]) +@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) +@pytest.mark.skip_global_cleanup +def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, + ignore_eos: bool, include_stop_str_in_output: bool): + """ + Test the behavior of the StopChecker's maybe_stop_sequence method + when an EOS token is encountered. + + This test covers: + - When the EOS token should stop the sequence and be removed from the output + - When the EOS token should stop the sequence and be included in the output + - When the EOS token should be ignored, and the sequence continues + """ + + tokenizer = MagicMock(spec=PreTrainedTokenizer) + get_tokenizer_for_seq = MagicMock(return_value=tokenizer) + stop_checker = StopChecker(max_model_len=1024, + get_tokenizer_for_seq=get_tokenizer_for_seq) + + seq = sequence_with_eos( + text=text_wo_eos, + eos_token=eos_token, + eos_token_id=eos_token_id, + ) + new_char_count = len(eos_token) + + # Note that `stop` and `stop_token_ids` are not specified + sampling_params = SamplingParams( + min_tokens=1, + ignore_eos=ignore_eos, + include_stop_str_in_output=include_stop_str_in_output) + + stop_checker.maybe_stop_sequence( + seq=seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + if ignore_eos: + assert seq.status == SequenceStatus.RUNNING + assert seq.output_text == text_wo_eos + eos_token + elif include_stop_str_in_output: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos + eos_token + else: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 5fb11b32bad6d..96f0d1142611b 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -48,6 +48,11 @@ def maybe_stop_sequence( # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED return From 616e600e0b092050213e79fd2a10baabb30dcf6d Mon Sep 17 00:00:00 2001 From: Marut Pandya Date: Tue, 28 May 2024 17:16:18 -0700 Subject: [PATCH 124/124] [Misc] add gpu_memory_utilization arg (#5079) Signed-off-by: pandyamarut --- benchmarks/benchmark_latency.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 3146fb33cc27e..f69d91a086a9f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -35,7 +35,8 @@ def main(args: argparse.Namespace): use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, - block_size=args.block_size) + block_size=args.block_size, + gpu_memory_utilization=args.gpu_memory_utilization) sampling_params = SamplingParams( n=args.n, @@ -214,5 +215,11 @@ def run_to_completion(profile_dir: Optional[str] = None): type=str, default=None, help='Path to save the latency results in JSON format.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') args = parser.parse_args() main(args)