
Merge pull request vllm-project#7 from KuntaiDu/jiayi-dev-v2
Add support for LMCache
KuntaiDu authored Sep 15, 2024
2 parents 80b4200 + 01fe335 commit 1f47731
Showing 4 changed files with 82 additions and 60 deletions.
5 changes: 1 addition & 4 deletions tests/kv_transfer/test_lookup_buffer.py
@@ -27,10 +27,7 @@ def test_run(my_rank, buffer, device):
placeholder = torch.tensor([1]).to(device)

buffer.insert(tokens, roi, key, value, placeholder)

#for i in range(2000):
# print("Here:", i)
# time.sleep(0.01)

torch.distributed.barrier()

# drop_select
15 changes: 11 additions & 4 deletions vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py
@@ -226,8 +226,14 @@ def send_tensor_wrapper(self, tensor):
with self.buffer_size_lock:
self.buffer_size = self.buffer_size - tensor_size
except Exception as e:
logger.error("Encountering exception in KV sending thread")
logger.error("%s", e)
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(),
str(tensor),
str(e))
import traceback
traceback.print_exc()



def block_if_full(self):
"""
@@ -279,10 +285,11 @@ def recv_tensor(self) -> Optional[torch.Tensor]:
try:
tensor = future.result()
except Exception as e:
# the underlying pipe is likely broken
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)

#tensor = self._recv_impl()
# fault tolerance: if the pipe is broken, return None
return None

if tensor.numel() == 1 and tensor.item() == NONE_INT:
return None
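The change to `recv_tensor` above makes the receiving thread degrade gracefully instead of crashing when the underlying pipe breaks. A minimal sketch of that pattern, assuming a blocking `recv_impl` callable and a caller-supplied integer sentinel (the real module defines its own `NONE_INT` constant for this purpose):

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

import torch


def recv_tensor_safely(executor: ThreadPoolExecutor,
                       recv_impl: Callable[[], torch.Tensor],
                       none_sentinel: int) -> Optional[torch.Tensor]:
    # Run the blocking receive on a worker thread.
    future = executor.submit(recv_impl)
    try:
        tensor = future.result()
    except Exception as e:
        # Fault tolerance: the underlying pipe is likely broken, so log the
        # error and return None instead of propagating the exception.
        print(f"Exception in KV receiving thread: {e}")
        return None
    # A one-element tensor holding the sentinel encodes an explicit "None".
    if tensor.numel() == 1 and tensor.item() == none_sentinel:
        return None
    return tensor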
111 changes: 64 additions & 47 deletions vllm/distributed/kv_transfer/vllm_adapter.py
@@ -1,14 +1,19 @@
"""vLLM distributed KV cache transfer API.
These APIs are used in `vllm/worker/model_runner.py`.
These APIs are used in `vllm/worker/worker_base.py`.
Currently supporting TP and PP, but TP and PP must be the same.
Currently supporting TP. The TP degree of the prefill and decode instances needs to be the same.
Workflow:
- In prefill instance, vLLM `insert` that buffers the KV cache into lookup buffer.
Workflow (disaggregated prefill)
- In prefill instance
- After prefill, vLLM `insert`s its KV caches into a lookup buffer.
- The prefill instance will also open up a thread that listens for `drop_select` requests.
- In decode instance
- vLLM first runs `drop_select` to send input tokens and a mask on input tokens to sender
- The prefill instance send back the matching KV caches
- vLLM then store the KV cache into paged memory.
- vLLM first runs `drop_select` to send the input tokens and a mask over those tokens (we call it roi, region of interest) to the prefill instance
- The prefill instance then responds to the `drop_select` request by
- Finding a match in the current lookup buffer.
- Cloning and sending the matched item out.
- Deleting the matched item from the lookup buffer to free up GPU memory.
- The decode vLLM then stores the received KV cache into paged memory.
"""
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from collections import defaultdict, deque
@@ -30,7 +35,6 @@
from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer

from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from copy import deepcopy

assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \
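To make the workflow in the docstring concrete, here is a rough sketch of the insert/drop_select exchange it describes. The `buffer` argument stands for the lookup buffer constructed later in this file; the helper names are illustrative, and the assumption that `drop_select` returns a (tokens, roi, keys, values, hidden) tuple mirrors the arguments that `insert` takes rather than a documented API contract.

import torch


def prefill_publish(buffer, tokens, keys, values, hidden):
    # Prefill side: after the forward pass, publish the KV cache so a decode
    # instance can fetch it later. The roi mask marks which of `tokens` the
    # cached keys/values cover (here: all of them).
    roi = torch.ones_like(tokens, dtype=bool)
    buffer.insert(tokens, roi, keys, values, hidden)


def decode_fetch(buffer, tokens):
    # Decode side: ask the prefill instance for a matching entry. On the
    # prefill side the matched item is cloned, sent back, and then dropped
    # from the lookup buffer to free GPU memory.
    roi = torch.ones_like(tokens, dtype=bool)
    ret = buffer.drop_select(tokens, roi)
    if ret[0] is None:
        return None  # no match: fall back to running prefill locally
    _tokens, _roi, keys, values, hidden = ret
    return keys, values, hidden  # caller writes these into paged KV memory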
@@ -66,57 +70,68 @@ def __init__(
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
# FIXME(Kuntai): remove this hardcoding
lookup_buffer_size: int = 1e10
):


# init pipe
self.device_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
)
self.cpu_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
"gloo"
)

# init two pipes: one for send and one for recv
if IS_KV_PREFILL_INSTANCE or IS_LMCACHE_INSTANCE:
self.lookup_buffer_size = lookup_buffer_size

if IS_LMCACHE_INSTANCE:
# when vLLM is connected with LMCache
# it needs to both send and recv KV cache
self.send_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
)
self.recv_pipe = TorchDistributedPipe(
self.send_signal_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
"gloo",
)
elif IS_KV_DECODE_INSTANCE:
self.recv_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
)
self.send_pipe = TorchDistributedPipe(
self.recv_signal_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
"gloo",
)
self.send_buffer = SimpleKVLookupBuffer(
self.send_signal_pipe,
self.send_pipe,
self.lookup_buffer_size)
self.recv_buffer = SimpleKVLookupBuffer(
self.recv_signal_pipe,
self.recv_pipe,
self.lookup_buffer_size)
else:
# when performing disaggregated prefill, only 1 pipe is needed
# at the prefill instance this pipe is used to send KV caches
# at the decode instance this pipe is used to recv KV caches
self.pipe = TorchDistributedPipe(
group_ranks,
local_rank,
torch_distributed_backend,
)

self.signal_pipe = TorchDistributedPipe(
group_ranks,
local_rank,
"gloo",
)
buffer = SimpleKVLookupBuffer(
self.signal_pipe,
self.pipe,
self.lookup_buffer_size)
self.send_buffer = buffer
self.recv_buffer = buffer

# FIXME(Jiayi): buffer initialization should be adapted accordingly
# Signal pipe needs to be initialized on both vllm and lmc side

# init lookup buffer
# TODO: replace this 1e9 with a configurable parameter or a constant
self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10)

def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: ModelInputForGPUWithSamplingMetadata,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors],
) -> None:
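The `__init__` hunk above wires the buffers in two modes: an LMCache instance builds separate send and recv buffers, each backed by a device-backend data pipe plus a gloo signal pipe, while a plain disaggregated-prefill instance builds a single buffer that serves as both `send_buffer` and `recv_buffer` (send on the prefill instance, recv on the decode instance). A condensed sketch of that decision, reusing the constructors imported in this file; the wrapper function itself is illustrative, not part of the diff:

from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer


def make_kv_buffers(group_ranks, local_rank, backend, buffer_size, is_lmcache):
    def data_pipe():
        # carries the KV tensors over the device backend (e.g. nccl)
        return TorchDistributedPipe(group_ranks, local_rank, backend)

    def signal_pipe():
        # carries small CPU-side control messages over gloo
        return TorchDistributedPipe(group_ranks, local_rank, "gloo")

    if is_lmcache:
        # LMCache both sends and receives KV caches, so it gets two buffers.
        send_buffer = SimpleKVLookupBuffer(signal_pipe(), data_pipe(), buffer_size)
        recv_buffer = SimpleKVLookupBuffer(signal_pipe(), data_pipe(), buffer_size)
        return send_buffer, recv_buffer

    # Disaggregated prefill: one buffer, used for send on the prefill
    # instance and for recv on the decode instance.
    shared = SimpleKVLookupBuffer(signal_pipe(), data_pipe(), buffer_size)
    return shared, shared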
@@ -152,7 +167,7 @@ def send_kv_caches_and_hidden_states(

keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.buffer.insert(
self.send_buffer.insert(
current_tokens,
torch.ones_like(current_tokens, dtype=bool),
keys,
@@ -167,10 +182,11 @@
def recv_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: ModelInputForGPUWithSamplingMetadata,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> List[Union[torch.Tensor, IntermediateTensors], bool, ModelInputForGPUWithSamplingMetadata]:
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]:

# When this flag is set to False, it means that at least one request is
# missing its KV cache, so the model still needs to be executed.
bypass_model_exec = True

# This is disagg decode instance, during prefill state
@@ -197,12 +213,12 @@ def recv_kv_caches_and_hidden_states(
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)

ret = self.buffer.drop_select(
ret = self.recv_buffer.drop_select(
current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
self.bypass_model_exec = False
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue

@@ -219,9 +235,7 @@
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):

# get kv cache
kv_cache = kv_caches[i - model_executable.model.start_layer]
# get corresponding layer
layer = model_executable.model.layers[i]

key_cache, value_cache = kv_cache[0], kv_cache[1]
@@ -247,7 +261,7 @@ def recv_kv_caches_and_hidden_states(
return None, bypass_model_exec, None

if not is_complete:
rebuilt_model_input = self.adpat_model_input(
rebuilt_model_input = self.build_partial_prefill_input(
model_input,
input_tokens_list,
num_computed_tokens_list,
@@ -266,15 +280,15 @@ def recv_kv_caches_and_hidden_states(
return hidden_or_intermediate_states, bypass_model_exec, model_input


def adpat_model_input(
def build_partial_prefill_input(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
model_input: "ModelInputForGPUWithSamplingMetadata",
input_tokens_list: List[torch.Tensor],
num_computed_tokens_list: List[int],
start_pos_list: List[int],
slot_mapping_flat: torch.Tensor,
device: torch.device,
) -> ModelInputForGPUWithSamplingMetadata:
) -> "ModelInputForGPUWithSamplingMetadata":
rebuilt_input_tokens = []
rebuilt_input_positions= []
rebuilt_query_lens = []
@@ -290,6 +304,7 @@ def adpat_model_input(
rebuilt_context_lens_tensor = []
rebuilt_selected_token_indices = []

# recounting query and context lengths
for idx in range(len(input_tokens_list)):
token_tensor = input_tokens_list[idx]
num_token = len(token_tensor)
@@ -350,6 +365,8 @@ def adpat_model_input(
dtype=model_input.sampling_metadata.selected_token_indices.dtype,
).to(device)

# import here to avoid circular import.
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
rebuilt_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens = torch.cat(rebuilt_input_tokens).to(device),
input_positions = torch.cat(rebuilt_input_positions).to(device),
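The renamed `build_partial_prefill_input` handles the case where only some sequences got a full cache hit: tokens whose KV cache arrived from the prefill instance become context, and the remainder become the query that still has to be prefilled locally. A hedged sketch of just that recounting step; the helper name and return shape are illustrative, not the exact `ModelInputForGPUWithSamplingMetadata` fields:

from typing import List, Tuple

import torch


def recount_lengths(input_tokens_list: List[torch.Tensor],
                    num_computed_tokens_list: List[int]) -> Tuple[List[int], List[int]]:
    query_lens: List[int] = []
    context_lens: List[int] = []
    for tokens, num_computed in zip(input_tokens_list, num_computed_tokens_list):
        total = len(tokens)
        # Received KV caches count as context; the rest must still run
        # through the local model during this partial prefill.
        context_lens.append(num_computed)
        query_lens.append(total - num_computed)
    return query_lens, context_lens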
11 changes: 6 additions & 5 deletions vllm/distributed/parallel_state.py
@@ -191,10 +191,10 @@ def init_distributed_environment(
# this backend is used for WORLD
maybe_disagg_world_size = world_size
maybe_disagg_rank = rank
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE:
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE:
maybe_disagg_world_size = world_size * 2
logger.debug("Disaggregated prefill enabled.")
if dist_kv.IS_KV_PREFILL_INSTANCE:
if dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE:
# for prefill, the ranks are [0, world_size)
maybe_disagg_rank = rank
else:
@@ -227,7 +227,7 @@ def init_distributed_environment(
if _WORLD is None:
ranks = [[i for i in range(world_size)]]
# offset the distributed group
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE:
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE:
ranks = include_decoding_groups_if_disagg_enabled(
ranks, world_size)

@@ -289,7 +289,7 @@ def initialize_model_parallel(
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE:
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE:
# Disaggregated prefill enabled
# The world_size for this vLLM instance is tp * pp, but torch.distributed spans 2 vLLM instances, so its world size is 2 * tp * pp
# Adjust the world_size to match.
@@ -341,7 +341,8 @@ def initialize_model_parallel(
use_custom_allreduce=False)
logger.debug("_PP initialized for rank %d", torch.distributed.get_rank())

if dist_kv.IS_DISTRIBUTED_KV_INSTANCE:
# TODO(Jiayi): perhaps we need to separate lmcache and disagg
if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE:
global _DISAGG
logger.debug("Disaggregated prefill enabled, create _DISAGG group")
group_ranks = []
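The parallel_state.py changes extend the existing disaggregated-prefill rank bookkeeping to LMCache instances: the torch.distributed world spans two vLLM instances, prefill (and LMCache) ranks stay in [0, world_size), and decode ranks are offset. A small sketch of that adjustment, assuming the collapsed else-branch offsets decode ranks by world_size as the surrounding comments suggest:

def maybe_disagg_rank_and_world(rank: int, world_size: int,
                                is_disagg_or_lmcache: bool,
                                is_prefill_or_lmcache: bool):
    # Without disaggregated prefill or LMCache, nothing changes.
    if not is_disagg_or_lmcache:
        return rank, world_size
    # torch.distributed sees both vLLM instances, so the global world doubles.
    disagg_world_size = world_size * 2
    if is_prefill_or_lmcache:
        # prefill ranks: [0, world_size)
        return rank, disagg_world_size
    # decode ranks: [world_size, 2 * world_size)
    return rank + world_size, disagg_world_size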
