[Kernel] Attention.forward with unified_attention when use_direct_call=True #11967
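For context on the title: when use_direct_call=True, the attention layer invokes the unified attention entry point as an ordinary Python function instead of dispatching through a registered custom op, and the unified function resolves the layer and the KV cache bound to it from the static forward context by layer name. The sketch below is a rough illustration under those assumptions only; the class, function, and attribute names (unified_attention_sketch, backend_forward, current_virtual_engine) are placeholders, not the actual vLLM code.

    # Rough sketch only -- placeholder names, not the vLLM implementation.
    import torch
    import torch.nn as nn

    def unified_attention_sketch(query: torch.Tensor,
                                 key: torch.Tensor,
                                 value: torch.Tensor,
                                 layer_name: str,
                                 static_forward_context: dict) -> torch.Tensor:
        # Look up the attention module registered under this layer name and the
        # KV cache that was bound to it (treated as an opaque object here).
        layer = static_forward_context[layer_name]
        kv_cache = layer.kv_cache[layer.current_virtual_engine]
        return layer.backend_forward(query, key, value, kv_cache)

    class AttentionSketch(nn.Module):
        def __init__(self, layer_name: str, use_direct_call: bool):
            super().__init__()
            self.layer_name = layer_name
            self.use_direct_call = use_direct_call

        def forward(self, q, k, v, static_forward_context):
            if self.use_direct_call:
                # Eager path: plain Python call, no custom-op dispatch overhead.
                return unified_attention_sketch(q, k, v, self.layer_name,
                                                static_forward_context)
            # A compiled path would instead dispatch through a registered
            # custom op; that path is omitted from this sketch.
            raise NotImplementedError("custom-op path omitted from this sketch")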
Changes from all commits: f1c10d1, 291cc7a, 1f915ca, 4e135d6, bbac1fb, c8283cd, 631bce7
@@ -12,7 +12,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
 from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
@@ -108,6 +108,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
                       torch.tensor([], dtype=torch.float32,
                                    device=self.device))
                      for _ in range(num_layers)]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [kv_caches])

         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,

Review comment (on the bind_kv_cache call above): TPU KV caches are actually two separate tensors. Would that matter here?

Reply: It does not matter. The detailed structure of the KV cache for a single layer (e.g. a torch.Tensor vs. a Tuple[torch.Tensor, ...]) is never accessed outside the attention backends (see the illustrative sketch after this hunk).
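To illustrate the point in the reply: below is a minimal sketch of what a binding helper of this shape could do, assuming the static forward context maps layer names to attention-like modules and that the outer list of kv_caches is indexed by virtual engine. This is not the real vllm.utils.bind_kv_cache; the names and the layer-ordering convention are assumptions.

    # Illustrative sketch only -- not the real vllm.utils.bind_kv_cache.
    from typing import Any, Dict, List

    def bind_kv_cache_sketch(static_forward_context: Dict[str, Any],
                             kv_caches: List[List[Any]]) -> None:
        """Attach each layer's cache entry to its attention module.

        kv_caches is assumed to be indexed as [virtual_engine][layer].  Each
        entry is treated as opaque: it may be one torch.Tensor or a tuple of
        (key_cache, value_cache) tensors (the TPU layout); only the attention
        backend ever looks inside it.
        """
        # Assumption: iteration order of the context matches layer order.
        for layer_index, layer_name in enumerate(static_forward_context):
            static_forward_context[layer_name].kv_cache = [
                per_engine_caches[layer_index] for per_engine_caches in kv_caches
            ]

With that reading, the call in this hunk binds a single virtual engine whose per-layer entries happen to be (key, value) tuples, which is why the tuple-vs-tensor distinction never surfaces here.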
@@ -170,6 +172,8 @@ def initialize_cache(
                                        device="cpu")
             cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      [self.tpu_cache])
         self._warmup_model()

     def _warmup_model(self) -> None:
Review comment (on the bind_kv_cache call above): Is this extra list construction around self.tpu_cache for pipeline parallelism? If yes, can you add a comment here? (See the sketch below for the assumed nesting.)
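If the outer list is indeed one entry per virtual engine (what the comment above asks about), the structure being bound would look roughly like the following; the shapes, sizes, and the single-virtual-engine case shown here are illustrative assumptions, not values from this PR.

    # Illustrative only -- assumed shapes, not taken from the PR.
    import torch

    num_layers, num_blocks, block_size, num_kv_heads, head_size = 2, 4, 16, 8, 64

    # TPU-style per-layer entry: separate key and value tensors in a tuple.
    tpu_cache = [
        (torch.zeros(num_blocks, block_size, num_kv_heads, head_size),
         torch.zeros(num_blocks, block_size, num_kv_heads, head_size))
        for _ in range(num_layers)
    ]

    # The outer list would hold one such per-layer list per virtual engine
    # (pipeline-parallel stage).  Without pipeline parallelism there is a
    # single virtual engine, hence the extra [ ... ] wrapping in the diff:
    single_engine_kv_caches = [tpu_cache]        # [virtual_engine][layer]
    # bind_kv_cache(static_forward_context, single_engine_kv_caches)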