From 1a36287b89f337057ebeb5d1bee30567e985b444 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Mon, 19 Aug 2024 13:00:09 +0800 Subject: [PATCH 01/15] [Bugfix] Fix xpu build (#7644) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9e34433eff0d8..ef599b613667b 100644 --- a/setup.py +++ b/setup.py @@ -279,7 +279,7 @@ def _build_custom_ops() -> bool: def _build_core_ext() -> bool: - return not _is_neuron() and not _is_tpu() and not _is_openvino() + return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()) def get_hipcc_rocm_version(): From df845b2b46c3e30f5bd3e3be286285ed148323fc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 19 Aug 2024 09:29:31 -0700 Subject: [PATCH 02/15] [Misc] Remove Gemma RoPE (#7638) --- vllm/model_executor/layers/rotary_embedding.py | 15 --------------- vllm/model_executor/models/gemma.py | 8 +++----- vllm/model_executor/models/gemma2.py | 10 ++++------ 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index e6ee2b967c8da..0562b71aa7493 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -93,11 +93,6 @@ def __init__( def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" - # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. - # However, we use `torch.arange(..., dtype=torch.float)` instead to - # avoid numerical issues with large base values (e.g., 10000000). - # This may cause a slight numerical difference between the HF - # implementation and ours. # NOTE(woosuk): To exactly match the HF implementation, we need to # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause @@ -724,16 +719,6 @@ def forward( return query, key -class GemmaRotaryEmbedding(RotaryEmbedding): - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 - inv_freq = 1.0 / (base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / - self.rotary_dim)) - return inv_freq - - class Llama3RotaryEmbedding(RotaryEmbedding): def __init__( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 14d1578863e5e..7a9ee3d9477ca 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -148,14 +148,12 @@ def __init__(self, quant_config=quant_config, ) - # TODO(woosuk): Use the `get_rope` interface. - self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position_embeddings=max_position_embeddings, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) self.attn = Attention(self.num_heads, self.head_dim, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index aa9cff02283c0..ff547c2c3b8ab 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -130,14 +130,12 @@ def __init__(self, bias=config.attention_bias, quant_config=quant_config, ) - # TODO(woosuk): Use the `get_rope` interface. - self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = get_rope( self.head_dim, - self.head_dim, - max_position_embeddings, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) # FIXME(woosuk): While Gemma 2 uses sliding window attention for every From 3ac50b47d0f718364d05b1bcb93743e15be6a37c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 19 Aug 2024 11:52:07 -0700 Subject: [PATCH 03/15] [MISC] Add prefix cache hit rate to metrics (#7606) --- tests/core/block/test_prefix_caching_block.py | 26 +++++++++ tests/prefix_caching/test_prefix_caching.py | 7 +++ vllm/core/block/common.py | 53 +++++++++++++++++++ vllm/core/block/cpu_gpu_block_allocator.py | 5 ++ vllm/core/block/interfaces.py | 10 ++++ vllm/core/block/naive_block.py | 3 ++ vllm/core/block/prefix_caching_block.py | 10 +++- vllm/core/block_manager_v1.py | 31 +++++++++-- vllm/core/block_manager_v2.py | 3 ++ vllm/core/embedding_model_block_manager.py | 4 ++ vllm/core/evictor_v2.py | 15 +++--- vllm/core/interfaces.py | 6 +++ vllm/core/scheduler.py | 5 +- vllm/engine/llm_engine.py | 12 ++++- vllm/engine/metrics.py | 23 +++++++- vllm/engine/metrics_types.py | 3 ++ 16 files changed, 200 insertions(+), 16 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5fb8ec06cfa03..c2226870c2e83 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): assert new_block[0].block_id == last_block_id + # Test case for cache mertics + @staticmethod + def test_metric(): + block_size = 16 + allocator = PrefixCachingBlockAllocator(num_blocks=4, + block_size=block_size) + # Test when no query (0/0) + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + token_ids = list(range(block_size)) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 0/1 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 1/2 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.5 + + # Test more than one block + for _ in range(2, 1005): + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + assert allocator.get_prefix_cache_hit_rate() > 0.99 + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 9821dbd066a59..2dff84b812b89 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -34,6 +34,9 @@ def test_block_allocator( assert (first_block == second_block) assert (second_block.ref_count == 2) + # Check metric: 1 hit of 2 queries + assert block_allocator.get_prefix_cache_hit_rate() == 0.5 + # Free the first_block and confirm that the ref_count is correctly # decremented on the second block block_allocator.free(first_block) @@ -48,6 +51,10 @@ def test_block_allocator( assert (first_block == second_block) assert (first_block.block_hash == block_hash) + # Allocate one more time to get 3/4 hit rate for easy checking + block_allocator.allocate(block_hash, 0) + assert block_allocator.get_prefix_cache_hit_rate() == 0.75 + @pytest.mark.parametrize("num_blocks", [16]) def test_eviction(num_blocks: int, ): diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 1e808e21b72e5..eb190adfbe802 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -282,6 +283,58 @@ def ids(self) -> List[int]: return self._block_ids +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. + if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + def get_all_blocks_recursively(last_block: Block) -> List[Block]: """Retrieves all the blocks in a sequence starting from the last block. diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5287cd9c1bfb3..c6330df2a485a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -323,6 +323,11 @@ def get_common_computed_block_ids( def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index ab39832bc1f6e..f26bc761c9967 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -186,6 +186,11 @@ def get_num_blocks_touched(self, num_lookahead_slots: int = 0) -> int: pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class NoFreeBlocksError(ValueError): pass @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block: There is at most one null block per allocator. """ pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 14a62c2e7190e..1643fd69c58ab 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e145eeba2d66e..432a6651ab07a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,9 +1,8 @@ """Token blocks.""" - from os.path import commonprefix from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple -from vllm.core.block.common import (CopyOnWriteTracker, +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import (BlockPool, NaiveBlock, @@ -107,6 +106,8 @@ def __init__( self._cow_tracker = CopyOnWriteTracker( refcounter=self._refcounter.as_readonly()) + self.metric_data = CacheMetricData() + # Implements Block.Factory. def _create_block( self, @@ -155,9 +156,11 @@ def allocate_immutable_block(self, cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: + self.metric_data.query(hit=True) block.block_id = cached_block_id self._incr_refcount_cached_block(block) return block + self.metric_data.query(hit=False) self._block_pool.free_block(block) # No cached block => Allocate a new block @@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int: def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None if block.content_hash in self._cached_blocks: diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index ad26d3c516ff0..0af04399a4b31 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.common import CacheMetricData from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool: def update_hash(self, block_hash: int, block: PhysicalTokenBlock): pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class CachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. @@ -85,6 +91,8 @@ def __init__(self, self.default_hash_ctr = count() + self.cache_metric_data = CacheMetricData() + def allocate_block(self, block_hash: int, num_hashed_tokens: int) -> PhysicalTokenBlock: if self.current_num_blocks == self.num_blocks: @@ -105,15 +113,17 @@ def allocate(self, num_hashed_tokens: int = 0) -> PhysicalTokenBlock: if block_hash is None: block_hash = next(self.default_hash_ctr) + if block_hash in self.evictor: assert block_hash not in self.cached_blocks block = self.evictor.remove(block_hash) assert block.ref_count == 0 self.cached_blocks[block_hash] = block - block.ref_count += 1 - assert block.block_hash == block_hash - return block - if block_hash not in self.cached_blocks: + + if block_hash in self.cached_blocks: + self.cache_metric_data.query(hit=True) + else: + self.cache_metric_data.query(hit=False) self.cached_blocks[block_hash] = self.allocate_block( block_hash, num_hashed_tokens) block = self.cached_blocks[block_hash] @@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): del self.cached_blocks[old_hash] self.cached_blocks[block_hash] = block + def get_prefix_cache_hit_rate(self) -> float: + return self.cache_metric_data.get_hit_rate() + class UncachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. @@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): raise NotImplementedError( "Invalid codepath for uncached block allocator.") + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class BlockSpaceManagerV1(BlockSpaceManager): """Manages the mapping between logical and physical token blocks.""" @@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + if device == Device.GPU: + return self.gpu_allocator.get_prefix_cache_hit_rate() + if device == Device.CPU: + return self.cpu_allocator.get_prefix_cache_hit_rate() + raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b48ea1b19b82a..b7d9451f18067 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + def _can_swap(self, seq_group: SequenceGroup, device: Device, diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7ceb..3d864a73f91d0 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -2,6 +2,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class EmbeddingModelBlockSpaceManager(BlockSpaceManager): @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self, def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py index 5b1a208b7c866..0b943e6e65f1c 100644 --- a/vllm/core/evictor_v2.py +++ b/vllm/core/evictor_v2.py @@ -85,19 +85,21 @@ def evict(self) -> Tuple[int, int]: if len(self.free_table) == 0: raise ValueError("No usable cache memory left") - evicted_block = next(iter(self.free_table.values())) - evicted_block_id = next(iter(self.free_table.keys())) + evicted_block, evicted_block_id = None, None # The blocks with the lowest timestamps should be placed consecutively # at the start of OrderedDict. Loop through all these blocks to # find the one with maximum number of hashed tokens. for _id, block in self.free_table.items(): + if evicted_block is None: + evicted_block, evicted_block_id = block, _id + continue if evicted_block.last_accessed < block.last_accessed: break - if (evicted_block.last_accessed == block.last_accessed and - evicted_block.num_hashed_tokens < block.num_hashed_tokens): - evicted_block = block - evicted_block_id = _id + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block, evicted_block_id = block, _id + assert evicted_block is not None + assert evicted_block_id is not None self.free_table.pop(evicted_block_id) return evicted_block_id, evicted_block.content_hash @@ -110,7 +112,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed - self.free_table.move_to_end(block_id) def remove(self, block_id: int): if block_id not in self.free_table: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b8..becd0d2e7f849 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -5,6 +5,7 @@ from typing import Tuple from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class AllocStatus(enum.Enum): @@ -116,3 +117,8 @@ def get_common_computed_block_ids( @abstractmethod def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 802359d2283f7..3b716e32032c1 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -14,7 +14,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import PyObjectCache +from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -447,6 +447,9 @@ def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fcf45a38b9425..36cb6ce795f3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -47,7 +47,7 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import Counter, Device from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -1390,6 +1390,13 @@ def _get_stats( for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 @@ -1498,6 +1505,9 @@ def _get_stats( # KV Cache Usage in % gpu_cache_usage_sys=gpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 1071786c27cd6..74277cae7c8ef 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -71,6 +71,17 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="CPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames, multiprocess_mode="sum") + # Prefix caching block hit rate + self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:cpu_prefix_cache_hit_rate", + documentation="CPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:gpu_prefix_cache_hit_rate", + documentation="GPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") # Iteration stats self.counter_num_preemption = self._counter_cls( @@ -351,7 +362,13 @@ def log(self, stats: Stats) -> None: stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) - + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + logger.info( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) if self.spec_decode_metrics is not None: logger.info( self._format_spec_decode_metrics_str( @@ -423,6 +440,10 @@ def _log_prometheus(self, stats: Stats) -> None: stats.gpu_cache_usage_sys) self._log_gauge(self.metrics.gauge_cpu_cache_usage, stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) # Iteration level data self._log_counter(self.metrics.counter_num_preemption, diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 7449aafc5aecb..1eccb23593408 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -32,6 +32,9 @@ class Stats: # KV Cache Usage in % gpu_cache_usage_sys: float cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float # Iteration stats (should have _iter suffix) num_prompt_tokens_iter: int From dad961ef5ca3893b78224323ec943dce9f52f868 Mon Sep 17 00:00:00 2001 From: Ali Panahi <64020589+c3-ali@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:47:00 -0700 Subject: [PATCH 04/15] [Bugfix] fix lora_dtype value type in arg_utils.py - part 2 (#5428) --- vllm/engine/arg_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8fca2cc049958..b23e166dc0d7b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, Union) +import torch + import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -113,7 +115,7 @@ class EngineArgs: fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: str = 'auto' + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 From 47b65a550866c7ffbd076ecb74106714838ce7da Mon Sep 17 00:00:00 2001 From: William Lin Date: Mon, 19 Aug 2024 13:52:13 -0700 Subject: [PATCH 05/15] [core] Multi Step Scheduling (#7000) Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 9 + tests/multi_step/__init__.py | 0 tests/multi_step/test_correctness.py | 85 +++++ tests/worker/test_model_input.py | 77 +++++ vllm/engine/arg_utils.py | 7 +- vllm/engine/async_llm_engine.py | 135 +++++++- vllm/executor/gpu_executor.py | 14 +- vllm/executor/ray_gpu_executor.py | 3 + vllm/sequence.py | 10 +- vllm/worker/model_runner_base.py | 46 ++- vllm/worker/multi_step_model_runner.py | 453 +++++++++++++++++++++++++ vllm/worker/multi_step_worker.py | 189 +++++++++++ vllm/worker/worker_base.py | 10 +- 13 files changed, 1004 insertions(+), 34 deletions(-) create mode 100644 tests/multi_step/__init__.py create mode 100644 tests/multi_step/test_correctness.py create mode 100644 vllm/worker/multi_step_model_runner.py create mode 100644 vllm/worker/multi_step_worker.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7babffc62f431..d583610a78655 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -311,6 +311,15 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py +- label: Multi-step Tests (4 GPUs) # 10min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/ + - tests/multi_step/test_correctness.py + commands: + - pytest -v -s multi_step/test_correctness.py + - label: Pipeline Parallelism Test # 23min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/tests/multi_step/__init__.py b/tests/multi_step/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness.py new file mode 100644 index 0000000000000..bc14311c66424 --- /dev/null +++ b/tests/multi_step/test_correctness.py @@ -0,0 +1,85 @@ +# Test the AsyncLLMEngine with multi-step-decoding + +from typing import List + +import pytest + +from ..utils import RemoteOpenAIServer + +MODELS = [ + "JackFram/llama-160m", +] +NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps +NUM_PROMPTS = [10] + +DEFAULT_SERVER_ARGS: List[str] = [ + "--disable-log-requests", + "--use-v2-block-manager", + "--worker-use-ray", + "--gpu-memory-utilization", + "0.85", + "--swap-space", + "16", +] + + +async def completions_with_server_args(prompts: List[str], model_name: str, + server_cli_args: List[str]): + + outputs = None + with RemoteOpenAIServer(model_name, server_cli_args) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5) + assert outputs is not None + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize(("tp_size, pp_size"), [ + (1, 1), + (2, 2), +]) +@pytest.mark.parametrize("eager_mode", [False, True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.asyncio +async def test_multi_step(example_prompts, model: str, tp_size: int, + pp_size: int, eager_mode: int, + num_scheduler_steps: int, num_prompts: int): + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] + ms_server_args = DEFAULT_SERVER_ARGS + \ + ["--num-scheduler-steps", f"{num_scheduler_steps}"] + + if eager_mode: + ms_server_args.append("--enforce-eager") + + distributed_args = [ + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + ] + + ref_completions = await completions_with_server_args( + prompts, model, server_args + distributed_args) + test_completions = await completions_with_server_args( + prompts, model, ms_server_args + distributed_args) + + def get_text_generations(completions): + return [x.text for x in completions.choices] + + ref_generations = get_text_generations(ref_completions) + test_generations = get_text_generations(test_completions) + assert ref_generations == test_generations diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2126fafb2323b..a57fdac803e42 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -10,6 +10,7 @@ from vllm.worker.embedding_model_runner import ( ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from vllm.worker.multi_step_model_runner import StatefulModelInput class MockAttentionBackend(AttentionBackend): @@ -154,3 +155,79 @@ def test_embedding_model_runner_input(): None) == getattr(attn_metadata, field.name, None) # Pooling metadata is not broadcast. assert received_model_input.pooling_metadata is None + + +def test_multi_step_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + frozen_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + is_last_step=True, + is_first_multi_step=False, + current_step=4, + last_sampled_token_ids=torch.ones((10, 1)), + is_multi_step=True, + num_queries=8, + num_seqs=5, + cached_outputs=[], + ) + + assert isinstance(model_input, StatefulModelInput) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + + receieved_frozen_input = received_model_input.frozen_model_input + + # Check that received copy has correct values. + assert isinstance(received_model_input, StatefulModelInput) + assert receieved_frozen_input.input_tokens is not None + assert (receieved_frozen_input.input_tokens == + frozen_model_input.input_tokens).all() + assert receieved_frozen_input.input_positions is not None + assert (receieved_frozen_input.input_positions == + frozen_model_input.input_positions).all() + assert receieved_frozen_input.multi_modal_kwargs is None + assert (frozen_model_input.multi_modal_kwargs == + frozen_model_input.multi_modal_kwargs) + assert receieved_frozen_input.lora_requests is None + assert (receieved_frozen_input.lora_requests == + frozen_model_input.lora_requests) + assert receieved_frozen_input.lora_mapping is None + assert ( + receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping) + for field in dataclasses.fields(AttentionMetadata): + assert getattr(receieved_frozen_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. + assert (receieved_frozen_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert receieved_frozen_input.sampling_metadata.seq_groups is None + + # check non frozen fields + assert received_model_input.is_last_step == model_input.is_last_step + assert (received_model_input.is_first_multi_step == + model_input.is_first_multi_step) + assert received_model_input.current_step == model_input.current_step + assert (received_model_input.last_sampled_token_ids == + model_input.last_sampled_token_ids).all() + assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b23e166dc0d7b..a3bb87bbe6748 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -853,6 +853,12 @@ def create_engine_config(self, ) -> EngineConfig: "in low performance due to small KV cache space. Consider " "setting --max-model-len to a smaller value.", max_model_len) + if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: + self.use_v2_block_manager = True + logger.warning( + "Enabled BlockSpaceManagerV2 because it is " + "required for multi-step (--num-scheduler-steps > 1)") + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -881,7 +887,6 @@ def create_engine_config(self, ) -> EngineConfig: ) if self.num_scheduler_steps > 1: - raise NotImplementedError("Multi-step is not yet supported.") if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index dced804fccca9..6385d3ca2297e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,9 +1,11 @@ import asyncio import time +from dataclasses import dataclass from functools import partial from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +import torch from transformers import PreTrainedTokenizer from typing_extensions import assert_never @@ -27,7 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -249,9 +252,25 @@ def has_new_requests(self): return not self._new_requests.empty() +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = \ + self.parallel_config.pipeline_parallel_size + self.cached_scheduler_outputs = [ + SchedulerOutputState() for _ in range(pipeline_parallel_size) + ] + async def step_async( self, virtual_engine: int ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: @@ -264,13 +283,39 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - # Execute the model. finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, @@ -279,15 +324,35 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, output) else: output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + else: + request_outputs = [] # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -297,6 +362,60 @@ async def step_async( return request_outputs + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 55976f430254c..7d40607e81791 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -69,13 +69,19 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.speculative_config is None: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") - else: + + if self.scheduler_config.is_multi_step: + worker_kwargs.update( + worker_module_name="vllm.worker.multi_step_worker", + worker_class_name="MultiStepWorker") + elif self.speculative_config: worker_kwargs.update( worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + else: + worker_kwargs.update(worker_module_name="vllm.worker.worker", + worker_class_name="Worker") + return worker_kwargs def _create_worker(self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 3a08ab4dbfd44..4c38cd1cbd546 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -94,6 +94,9 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: if self.speculative_config is not None: worker_module_name = "vllm.spec_decode.spec_decode_worker" worker_class_name = "create_spec_worker" + elif self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" else: worker_module_name = "vllm.worker.worker" worker_class_name = "Worker" diff --git a/vllm/sequence.py b/vllm/sequence.py index b15955cde76cf..206da192193dc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,6 @@ Tuple, Union, cast) import msgspec -import numpy import torch from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs @@ -1082,7 +1081,10 @@ class SamplerOutput( # On-device tensor containing the sampled token ids. sampled_token_ids: Optional[torch.Tensor] = None - sampled_token_ids_numpy: Optional[numpy.ndarray] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -1257,9 +1259,7 @@ def is_last_step(self) -> bool: assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] assert first_seq_group.state is not None - num_steps = first_seq_group.state.num_steps - current_step = first_seq_group.state.current_step - return num_steps - current_step == 1 + return first_seq_group.state.remaining_steps == 1 @property def current_step(self) -> int: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 46ac16b504bf4..90c39407d7266 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelRunnerInputBase") +T = TypeVar('T', bound="BroadcastableModelInput") def _add_attn_metadata_broadcastable_dict( @@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(ABC): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + +class BroadcastableModelInput(ABC): + + @abstractmethod def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some @@ -109,11 +117,25 @@ def from_broadcasted_tensor_dict( ) -> T: """ Pop fields from the given tensor_dict and populate a new instance of - ModelRunnerInputBase. + BroadcastableModelInput. """ raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. """ diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000000000..521205eca05af --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,453 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + +import torch + +from vllm import _custom_ops as ops +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceOutput) +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + + def pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output(input_metadata, copy_stream, + pinned_sampled_token_buffer, True) + self.pythonized = True + + def maybe_pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, + False) + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. + """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + return True + + +@dataclass(frozen=False) +class StatefulModelInput(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + cached_outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "StatefulModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. + self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) do not clobber any data structures that is also used by any + # enqueued forwards steps. For distributed case, only a single event is + # needed, but for single GPU case, since we can let the CPU run much + # further ahead, two events allow us to overlap the advance_step with + # the previous forward (ie using two DecodeWrappers for flashinfer + # backend) + self.step_cuda_events[(self.current_step + 1) & 1].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.cached_outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False)) + + +# MutableModelInputForGPUWithMultiStepMetadata is not subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + # used to copy tensors from GPU to CPU asynchronously + self._copy_stream = torch.cuda.Stream() + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: + model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> StatefulModelInput: + frozen_model_input = self._base_model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + num_seqs=len(frozen_model_input.seq_lens), + num_queries=len(frozen_model_input.query_lens), + ) + return model_input + + @torch.inference_mode() + def execute_model( + self, + model_input: StatefulModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. + if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + current_stream = torch.cuda.current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False)) + # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is + # ready. + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = [] + for output in model_input.cached_outputs: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + assert num_seqs > 0 + assert num_queries > 0 + assert num_seqs >= num_queries + + attn_metadata = frozen_model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + attn_metadata.advance_step(num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=frozen_model_input.input_tokens, + sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, + input_positions=frozen_model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + if frozen_model_input.seq_lens is not None: + for i in range(num_queries): + frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + + return model_input + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +def _pythonize_sampler_output(model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor) -> None: + """ This function is only called when the output tensors are ready. + See ModelOutput + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + sampling_metadata = frozen_model_input.sampling_metadata + + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + samples_list): + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] = [] + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + for parent_id, next_token_id in zip(parent_ids, next_token_ids): + # TODO(will): support logprobs + # Hard coded logprob + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {next_token_id: Logprob(logprob=-1)})) + output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000000000..6a6caba9371eb --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, + StatefulModelInput) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + base_model_runner.model_config, + base_model_runner.parallel_config, + base_model_runner.scheduler_config, + base_model_runner.device_config, + base_model_runner.cache_config, + load_config=base_model_runner.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + prompt_adapter_config=base_model_runner.prompt_adapter_config, + observability_config=base_model_runner.observability_config, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. + """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached decode metadata so that it can be recomputed on + # the workers + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. + for output in model_input.cached_outputs[:-1]: + output.sampled_token_ids = None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + model_input, worker_input = self._get_driver_input_and_broadcast( + execute_model_req) + assert isinstance(model_input, StatefulModelInput) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input = broadcast_data + assert isinstance(model_input, StatefulModelInput) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs. Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 905052d1a9515..9fddc863548eb 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,7 +16,9 @@ SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) logger = init_logger(__name__) @@ -220,7 +222,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -237,7 +239,7 @@ def _get_worker_input_from_broadcast( def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelRunnerInputBase, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -259,7 +261,7 @@ def _get_driver_input_and_broadcast( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Prepare the inputs to ModelRunner and workers. """ From 7601cb044ddfe920055f82ae9503729d4dde7259 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 20 Aug 2024 05:30:14 +0800 Subject: [PATCH 06/15] [Core] Support tensor parallelism for GGUF quantization (#7520) --- tests/models/test_gguf.py | 24 +++++++++++++---- vllm/model_executor/layers/linear.py | 26 ++++++++++++++----- .../layers/quantization/gguf.py | 4 --- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index 5971179f01211..196cd88e039a1 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -7,6 +7,7 @@ import pytest from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -20,7 +21,7 @@ MODELS = [ ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")), + filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), @@ -39,22 +40,36 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_models( + num_gpus_available, vllm_runner, example_prompts, model, dtype: str, max_tokens: int, num_logprobs: int, + tp_size: int, ) -> None: + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + original_model, gguf_model = model + tokenizer = AutoTokenizer.from_pretrained(original_model) + messages = [[{ + 'role': 'user', + 'content': prompt + }] for prompt in example_prompts] + example_prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + # Run unquantized model. with vllm_runner(model_name=original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as original_model: + tensor_parallel_size=tp_size) as original_model: original_outputs = original_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) @@ -63,8 +78,7 @@ def test_models( with vllm_runner(model_name=gguf_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as gguf_model: + tensor_parallel_size=tp_size) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b4cc6daa3c41e..3824ed3570aeb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -507,11 +507,16 @@ def weight_loader(self, loaded_shard_id if is_gguf_weight: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -863,8 +868,13 @@ def weight_loader(self, param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + input_dim = getattr(param, "input_dim", None) input_size = loaded_weight.shape[input_dim] param_data = param_data.narrow(input_dim, 0, input_size) @@ -976,6 +986,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) # Special case for GGUF @@ -986,7 +997,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a4e0a4d509608..a6a1ed5b0dee5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -5,7 +5,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -39,9 +38,6 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": - if get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "GGUF quantization hasn't supported tensor parallelism yet.") return cls() def get_quant_method(self, layer: torch.nn.Module, From da115230fdde197abce793288b80da5223902861 Mon Sep 17 00:00:00 2001 From: Andrew Song <40076917+a-ys@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:11:58 -0700 Subject: [PATCH 07/15] [Bugfix] Don't disable existing loggers (#7664) --- vllm/logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/logger.py b/vllm/logger.py index 3c6bf0803a624..77dddbfb60965 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -43,6 +43,7 @@ }, }, "version": 1, + "disable_existing_loggers": False } From 43735bf5e19eaf243b6edaa5af4c7561a14fc2f6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 19 Aug 2024 15:55:04 -0700 Subject: [PATCH 08/15] [TPU] Remove redundant input tensor cloning (#7660) --- vllm/worker/tpu_model_runner.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 14f14e40b4c0b..01daa64b5a32f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -516,27 +516,19 @@ def execute_model( raise ValueError( "TPUModelRunner does not support multi-step execution.") - def _execute_model(*args, clone: bool = False) -> torch.Tensor: + def _execute_model(*args): """Move input args from CPU to device and execute the model.""" - def _copy_to_device(x: torch.Tensor) -> torch.Tensor: - if clone: - # When x is a slice of a CPU tensor, XLA may copy the whole - # original tensor to TPU instead of only copying x. - # To avoid this, we copy x after cloning. - x = x.clone() - return x.to(self.device) - new_args = [] for arg in args: if isinstance(arg, torch.Tensor): - arg = _copy_to_device(arg) + arg = arg.to(self.device) elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = _copy_to_device(arg.slot_mapping) + arg.slot_mapping = arg.slot_mapping.to(self.device) if getattr(arg, "block_tables", None) is not None: - arg.block_tables = _copy_to_device(arg.block_tables) + arg.block_tables = arg.block_tables.to(self.device) if getattr(arg, "context_lens", None) is not None: - arg.context_lens = _copy_to_device(arg.context_lens) + arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) return self.model(*new_args) @@ -563,13 +555,9 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor: output_token_ids = _execute_model( model_input.token_ids[None, start_idx:end_idx], model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, - model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], - model_input.p[i:i + 1], - model_input.num_samples, - kv_caches, - clone=True) + model_input.attn_metadata, model_input.input_lens[i:i + 1], + model_input.t[i:i + 1], model_input.p[i:i + 1], + model_input.num_samples, kv_caches) # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx From 67e02fa8a405e1e1df0eb7428ad45eed20b0934b Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 19 Aug 2024 18:43:09 -0600 Subject: [PATCH 09/15] [Bugfix] use StoreBoolean instead of type=bool for --disable-logprobs-during-spec-decoding (#7665) Signed-off-by: Travis Johnson --- vllm/engine/arg_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a3bb87bbe6748..7f45c3d06375a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -664,8 +664,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--disable-logprobs-during-spec-decoding', - type=bool, + action=StoreBoolean, default=EngineArgs.disable_logprobs_during_spec_decoding, + nargs="?", + const="True", help='If set to True, token log probabilities are not returned ' 'during speculative decoding. If set to False, log probabilities ' 'are returned according to the settings in SamplingParams. If ' From e54ebc2f8f9d78f3113fb2b531058977e6031609 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 19 Aug 2024 17:50:59 -0700 Subject: [PATCH 10/15] [doc] fix doc build error caused by msgspec (#7659) --- docs/requirements-docs.txt | 1 + vllm/platforms/__init__.py | 47 +++++++++++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 9a5964ec65b99..6a8d99635b8f0 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1 sphinx-copybutton==0.5.2 myst-parser==2.0.0 sphinx-argparse==0.4.0 +msgspec # packages to install to build the documentation pydantic diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 99ba940e5d2ab..958f6c516a2f8 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,23 +1,54 @@ -import torch - from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform +# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because +# they only indicate the build configuration, not the runtime environment. +# For example, people can install a cuda build of pytorch but run on tpu. + +is_tpu = False +try: + import torch_xla.core.xla_model as xm + xm.xla_device(devkind="TPU") + is_tpu = True +except Exception: + pass + +is_cuda = False + +try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + finally: + pynvml.nvmlShutdown() +except Exception: + pass + +is_rocm = False + try: - import libtpu -except ImportError: - libtpu = None + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + finally: + amdsmi.amdsmi_shut_down() +except Exception: + pass -if libtpu is not None: +if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first from .tpu import TpuPlatform current_platform = TpuPlatform() -elif torch.version.cuda is not None: +elif is_cuda: from .cuda import CudaPlatform current_platform = CudaPlatform() -elif torch.version.hip is not None: +elif is_rocm: from .rocm import RocmPlatform current_platform = RocmPlatform() else: From 312f7612328fb8db9fb2dd9f94f2ea021c035357 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Tue, 20 Aug 2024 06:28:14 +0530 Subject: [PATCH 11/15] [Speculative Decoding] Fixing hidden states handling in batch expansion (#7508) --- tests/spec_decode/e2e/conftest.py | 25 ++++-- tests/spec_decode/e2e/test_mlp_correctness.py | 42 +++++++++ vllm/spec_decode/batch_expansion.py | 86 +++++++++++++------ vllm/spec_decode/spec_decode_worker.py | 5 +- vllm/spec_decode/top1_proposer.py | 2 +- vllm/spec_decode/util.py | 20 ++++- 6 files changed, 139 insertions(+), 41 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index d0f91a63b2d6a..a701f482b4ffb 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ensure_all_accepted=ensure_all_accepted) -def run_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - temperature: float, - seeded: bool, - print_tokens: bool = False, - ensure_all_accepted: bool = False): +def run_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + temperature: float, + seeded: bool, + print_tokens: bool = False, + ensure_all_accepted: bool = False, + expected_acceptance_rate: Optional[float] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero (or when temperature is > 0 and seeded). @@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + print(f'{acceptance_rate=}') + if ensure_all_accepted: assert acceptance_rate == 1.0 + + if expected_acceptance_rate is not None: + assert acceptance_rate >= expected_acceptance_rate - 1e-2 diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 25067e7a4262c..c72e4595fd335 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + }, +]) +@pytest.mark.parametrize("output_len", [2048]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify acceptance rate with different batch size and large output + length.""" + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=0.0, + seeded=True, + force_output_len=True, + expected_acceptance_rate=0.48) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index aec4847b96c35..ad6f3f313841d 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,6 +1,6 @@ from array import array from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -88,21 +88,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -145,10 +146,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -156,9 +158,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -176,23 +179,40 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states + if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -327,8 +347,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -353,24 +374,37 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 63a00139cc09d..acf77a7349eef 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -646,9 +646,8 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 1a56497030280..28f7f7eb069ab 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c6223a97dba10..b85f2a6f70ac0 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -123,7 +123,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. sampler_transposed here is used as the indicator for whether @@ -169,7 +169,23 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, From 0df7ec0b2d890799ca71e2f862fdff5fcc52cdc0 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 19 Aug 2024 19:55:04 -0700 Subject: [PATCH 12/15] [ci] Install Buildkite test suite analysis (#7667) Signed-off-by: kevin --- requirements-test.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 95909d37e2c94..cdbc3e50cc9ec 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -29,4 +29,5 @@ matplotlib # required for qwen-vl test aiohttp # quantization -bitsandbytes==0.42.0 \ No newline at end of file +bitsandbytes==0.42.0 +buildkite-test-collector==0.1.8 \ No newline at end of file From f4fc7337bfaf5f10b8da4ba547e4009179348a26 Mon Sep 17 00:00:00 2001 From: Zijian Hu Date: Mon, 19 Aug 2024 20:00:04 -0700 Subject: [PATCH 13/15] [Bugfix] support `tie_word_embeddings` for all models (#5724) --- vllm/model_executor/models/arctic.py | 2 ++ vllm/model_executor/models/baichuan.py | 2 ++ vllm/model_executor/models/bart.py | 2 ++ vllm/model_executor/models/blip2.py | 3 +++ vllm/model_executor/models/bloom.py | 9 +++++++-- vllm/model_executor/models/chatglm.py | 3 +++ vllm/model_executor/models/commandr.py | 3 +++ vllm/model_executor/models/dbrx.py | 3 +++ vllm/model_executor/models/deepseek.py | 2 ++ vllm/model_executor/models/gemma.py | 2 ++ vllm/model_executor/models/gemma2.py | 2 ++ vllm/model_executor/models/gpt2.py | 8 ++++++-- vllm/model_executor/models/gpt_bigcode.py | 10 ++++++++-- vllm/model_executor/models/gpt_neox.py | 4 +++- vllm/model_executor/models/internlm2.py | 2 ++ vllm/model_executor/models/jais.py | 8 ++++++-- vllm/model_executor/models/llava.py | 4 ++-- vllm/model_executor/models/llava_next.py | 4 ++-- vllm/model_executor/models/minicpmv.py | 4 ++++ vllm/model_executor/models/mixtral.py | 2 ++ vllm/model_executor/models/mixtral_quant.py | 2 ++ vllm/model_executor/models/opt.py | 8 ++++++-- vllm/model_executor/models/orion.py | 2 ++ vllm/model_executor/models/phi.py | 2 ++ vllm/model_executor/models/phi3_small.py | 3 ++- vllm/model_executor/models/phi3v.py | 2 ++ vllm/model_executor/models/qwen.py | 2 ++ vllm/model_executor/models/qwen2_moe.py | 2 ++ vllm/model_executor/models/stablelm.py | 2 ++ vllm/model_executor/models/xverse.py | 2 ++ 30 files changed, 90 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 74e534aa76a9d..28f69cfbc46bd 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -414,6 +414,8 @@ def __init__(self, config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index a11c7663263c6..73711d8eb5185 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -331,6 +331,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index ef988532ce126..f78400b0df7b3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -821,6 +821,8 @@ def __init__(self, lora_config: Optional[LoRAConfig] = None): super().__init__() + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.config = config self.model = BartModel(config, cache_config, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8cfd3c2672568..20dda2a67820d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -494,6 +494,9 @@ def __init__(self, super().__init__() + # currently all existing BLIP-2 models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 282a0f84eacb1..07ee0e3c531d0 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -276,7 +276,12 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = BloomModel(config, cache_config, quant_config) - self.lm_head = self.transformer.word_embeddings + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index b29ebe2f59e7b..4949d0232fabb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -356,6 +356,9 @@ def __init__( self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 0894f750e5fbf..f63cf246e510a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -321,6 +321,9 @@ def __init__( ) -> None: super().__init__() self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 7ebeca1a359ef..dca959798e8b2 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -362,6 +362,9 @@ def __init__( ): super().__init__() self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size self.transformer = DbrxModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f10977ed2c90d..7a27e1388e987 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -380,6 +380,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 7a9ee3d9477ca..e1041edf81b0a 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -331,6 +331,8 @@ def __init__( super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index ff547c2c3b8ab..5e0f8b70d4b80 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -323,6 +323,8 @@ def __init__( del lora_config # Unused. super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 4f2fe0c42a3ff..bfc231282952a 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -249,7 +249,11 @@ def __init__( cache_config, quant_config, prefix="transformer") - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b30af3599aa4d..b93fb8d69b2d7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -259,7 +259,13 @@ def __init__( self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index e61b4448981e8..2adecf7fa9ef8 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, - config, + config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -243,6 +243,8 @@ def __init__( config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 216458465513a..887a353df972c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -264,6 +264,8 @@ def __init__( self.output = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ec6bea920cc3a..a550f7e6c97a1 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -37,7 +37,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -291,7 +291,11 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = JAISModel(config, cache_config, quant_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 46db364895b13..6433ea380cbfe 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -313,7 +313,7 @@ def forward( 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends additional image tokens (denoted as `32000`), resulting in: `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, @@ -331,7 +331,7 @@ def forward( input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each input image. - + See also: :class:`LlavaImageInputs` """ diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c1277359182e4..c7cb243fa84da 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -545,7 +545,7 @@ def forward( 9047, 13566, 29901]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends additional image tokens (denoted as `32000`), resulting in: `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, @@ -566,7 +566,7 @@ def forward( batch. pixel_values: The pixels in each grid patch for each input image. image_sizes: The original `(height, width)` for each input image. - + See also: :class:`LlavaNextImageInputs` """ diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 729bd27c334d5..99a3c5dab39e4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -496,6 +496,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 587d2f26a2d5e..34f581ac78582 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -359,6 +359,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 812dce5d04771..8bdd52b343175 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -347,6 +347,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index b05f799e4dd2b..c0d2d537e731f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -307,7 +307,11 @@ def __init__( self.config = config self.quant_config = quant_config self.model = OPTModel(config, cache_config, quant_config) - self.lm_head = self.model.decoder.embed_tokens + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6923e11e288be..fab35f0b882a7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -262,6 +262,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 54f4dd2fcde0a..f31b5162aac96 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -260,6 +260,8 @@ def __init__( super().__init__() self.config = config + # lm_head use bias, cannot share word embeddings + assert not config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 98e344d483e29..df01bfa3d8e6e 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -368,6 +368,8 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -449,4 +451,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c8bb8a837c86..328f4e6fa827c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -477,6 +477,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a7485bcb489a0..b7d017d5f3ea6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -252,6 +252,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index e160c9a320820..6f838947fbf27 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -385,6 +385,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c98226d61a8a0..decbf89d27c7c 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -243,6 +243,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index e9bf67d314d0a..c0bafa9367e43 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -313,6 +313,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From 3d8a5f063d8a96ccfb8fc14d1d43b93cea0411a0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 19 Aug 2024 22:43:54 -0700 Subject: [PATCH 14/15] [CI] Organizing performance benchmark files (#7616) --- .buildkite/nightly-benchmarks/README.md | 9 ++-- .../benchmark-pipeline.yaml | 2 +- ...=> performance-benchmarks-descriptions.md} | 0 .../convert-results-json-to-markdown.py | 4 +- .../run-performance-benchmarks.sh} | 45 ++++++++++++------- 5 files changed, 36 insertions(+), 24 deletions(-) rename .buildkite/nightly-benchmarks/{tests/descriptions.md => performance-benchmarks-descriptions.md} (100%) rename .buildkite/nightly-benchmarks/{run-benchmarks-suite.sh => scripts/run-performance-benchmarks.sh} (90%) diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index c1aebaf5b3bbe..fbf41eb10a392 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -34,17 +34,18 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan Performance benchmark will be triggered when: - A PR being merged into vllm. -- Every commit for those PRs with `perf-benchmarks` label. +- Every commit for those PRs with `perf-benchmarks` label AND `ready` label. Nightly benchmark will be triggered when: -- Every commit for those PRs with `nightly-benchmarks` label. +- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. ## Performance benchmark details -See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. + +See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. #### Latency test @@ -68,7 +69,7 @@ Here is an example of one test inside `latency-tests.json`: In this example: - The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. -- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` +- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 8490c9f1da221..2b70e2da5d87c 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -21,7 +21,7 @@ steps: containers: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: limits: nvidia.com/gpu: 8 diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md similarity index 100% rename from .buildkite/nightly-benchmarks/tests/descriptions.md rename to .buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 534ecf17930e9..f90e464288cf1 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -174,8 +174,8 @@ def results_to_json(latency, throughput, serving): # document the result with open(results_folder / "benchmark_results.md", "w") as f: - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = read_markdown("../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md") results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh similarity index 90% rename from .buildkite/nightly-benchmarks/run-benchmarks-suite.sh rename to .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index f6e41fcfdd0be..a0b9a409b758d 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -37,9 +37,9 @@ check_hf_token() { ensure_sharegpt_downloaded() { local FILE=ShareGPT_V3_unfiltered_cleaned_split.json if [ ! -f "$FILE" ]; then - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE else - echo "$FILE already exists." + echo "$FILE already exists." fi } @@ -68,11 +68,29 @@ wait_for_server() { done' && return 0 || return 1 } +kill_processes_launched_by_current_bash() { + # Kill all python processes launched from current bash script + current_shell_pid=$$ + processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') + if [ -n "$processes" ]; then + echo "Killing the following processes matching '$1':" + echo "$processes" + echo "$processes" | xargs kill -9 + else + echo "No processes found matching '$1'." + fi +} + kill_gpu_processes() { - # kill all processes on GPU. - ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 - ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + ps -aux + lsof -t -i:8000 | xargs -r kill -9 + pkill -f pt_main_thread + # this line doesn't work now + # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 + pkill -f python3 + pkill -f /usr/bin/python3 + # wait until GPU memory usage smaller than 1GB while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do @@ -82,11 +100,6 @@ kill_gpu_processes() { # remove vllm config file rm -rf ~/.config/vllm - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" } upload_to_buildkite() { @@ -104,7 +117,7 @@ upload_to_buildkite() { fi # Use the determined command to annotate and upload artifacts - $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" } @@ -156,7 +169,7 @@ run_latency_tests() { latency_command: $latency, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$latency_command" @@ -166,7 +179,6 @@ run_latency_tests() { done } - run_throughput_tests() { # run throughput tests using `benchmark_throughput.py` # $1: a json file specifying throughput test cases @@ -214,7 +226,7 @@ run_throughput_tests() { throughput_command: $command, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$throughput_command" @@ -246,7 +258,6 @@ run_serving_tests() { continue fi - # get client and server arguments server_params=$(echo "$params" | jq -r '.server_parameters') client_params=$(echo "$params" | jq -r '.client_parameters') @@ -324,7 +335,7 @@ run_serving_tests() { client_command: $client, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" done @@ -341,6 +352,7 @@ main() { # dependencies (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which lsof) || (apt-get update && apt-get install -y lsof) # get the current IP address, required by benchmark_serving.py export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') @@ -359,7 +371,6 @@ main() { run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - # postprocess benchmarking results pip install tabulate pandas python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py From c4be16e1a70d50b781038990087a717fd8834d4a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 19 Aug 2024 23:22:49 -0700 Subject: [PATCH 15/15] [misc] add nvidia related library in collect env (#7674) --- collect_env.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/collect_env.py b/collect_env.py index 76df97b099b1b..839d54172e775 100644 --- a/collect_env.py +++ b/collect_env.py @@ -66,6 +66,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } DEFAULT_PIP_PATTERNS = { @@ -79,6 +81,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", }