Add miscellaneous updates (vllm-project#8)
WoosukKwon authored Mar 13, 2023
1 parent de10960 · commit cd9f1ac
Showing 7 changed files with 44 additions and 22 deletions.
15 changes: 8 additions & 7 deletions cacheflow/master/scheduler.py
@@ -158,8 +158,8 @@ def step(self) -> None:
# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
self._fetch_inputs()
if not self.swapped:
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
@@ -211,12 +211,13 @@ def step(self) -> None:
input_seq_groups.append(input_seq_group)

# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out):
self.controllers[0].execute_stage(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)

def post_step(
self,
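
Note: the scheduler change above does two things. _fetch_inputs() is now called only when no sequence group is sitting in CPU swap space (preserving the FCFS assumption noted in the comments), and the first pipeline stage is launched only when there is actual work to run or swap. A minimal sketch of that guard pattern, using hypothetical standalone names rather than the real Scheduler class:

    def step(scheduler) -> None:
        blocks_to_swap_in: dict = {}
        blocks_to_swap_out: dict = {}
        blocks_to_copy: dict = {}
        input_seq_groups: list = []

        # Admit new requests only when nothing is waiting in CPU swap space,
        # which keeps scheduling FCFS with respect to swapped-out groups.
        if not scheduler.swapped:
            scheduler._fetch_inputs()
            # ... allocate blocks for pending sequence groups here ...

        # Launch the first pipeline stage only when there is work to run or swap.
        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
            scheduler.controllers[0].execute_stage(
                input_seq_groups,
                blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
            )

Skipping execute_stage when there is nothing to do avoids dispatching empty work to the workers; the worker.py change below handles the same case on the receiving side.
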
11 changes: 5 additions & 6 deletions cacheflow/models/attention.py
@@ -12,7 +12,7 @@
class OPTCacheFlowAttention(nn.Module):

def __init__(self, scale: float) -> None:
super().__init__()
super(OPTCacheFlowAttention, self).__init__()
self.scale = float(scale)

self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -106,8 +106,8 @@ def forward(
output = output.view(-1, num_heads, head_size)

# Compute the attention op for prompts.
if input_metadata.num_prompts > 0:
num_prompt_tokens = sum(input_metadata.prompt_lens)
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
@@ -126,10 +126,9 @@ def forward(

if input_metadata.num_generation_tokens > 0:
# Compute the attention op for generation tokens.
start_idx = sum(input_metadata.prompt_lens)
self.single_query_cached_kv_attention(
output[start_idx:],
query[start_idx:],
output[num_prompt_tokens:],
query[num_prompt_tokens:],
key_cache,
value_cache,
input_metadata)
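
Note: the attention change replaces the per-call sum(input_metadata.prompt_lens) with a precomputed input_metadata.num_prompt_tokens and uses it both as the guard and as the split point: the first num_prompt_tokens rows of the flattened query/output tensors go through multi-query (prompt) attention, and the remaining rows go through single-query attention against the KV cache. A small sketch of that slicing with made-up shapes, not the real module:

    import torch

    # Two prompts of lengths 3 and 4 (7 prompt tokens) plus 2 sequences that
    # are each generating one token; all tokens are flattened along dim 0.
    num_prompt_tokens = 7
    num_generation_tokens = 2
    num_heads, head_size = 12, 64

    query = torch.randn(num_prompt_tokens + num_generation_tokens, num_heads, head_size)
    output = torch.empty_like(query)

    # Prompt rows are handled by multi-query attention; the remaining rows are
    # handled by single-query attention against the cached keys and values.
    prompt_query = query[:num_prompt_tokens]
    generation_query = query[num_prompt_tokens:]
    assert prompt_query.shape[0] + generation_query.shape[0] == output.shape[0]
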
18 changes: 14 additions & 4 deletions cacheflow/models/memory_analyzer.py
@@ -5,7 +5,7 @@
from cacheflow.models.utils import get_dtype_size
from cacheflow.models.utils import get_gpu_memory

_GiB = 1 << 30
_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:
@@ -117,9 +117,19 @@ def get_max_num_gpu_blocks(

def get_max_num_cpu_blocks(
self,
memory_utilization: float = 0.25,
swap_space: int,
) -> int:
swap_space = swap_space * _GiB
cpu_memory = get_cpu_memory()
usable_memory = int(memory_utilization * cpu_memory)
max_num_blocks = usable_memory // self._get_cache_block_size()
if swap_space > 0.8 * cpu_memory:
raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
'takes more than 80% of the available memory '
f'({cpu_memory / _GiB:.2f} GiB).'
'Please check the swap space size.')
if swap_space > 0.5 * cpu_memory:
print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
'takes more than 50% of the available memory '
f'({cpu_memory / _GiB:.2f} GiB).'
'This may slow the system performance.')
max_num_blocks = swap_space // self._get_cache_block_size()
return max_num_blocks
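
Note: get_max_num_cpu_blocks now takes an explicit swap_space budget in GiB instead of carving out a fixed fraction of host memory. It converts the budget to bytes, rejects budgets above 80% of host memory, warns above 50%, and sizes the CPU cache as swap_space // cache_block_size blocks. A self-contained sketch of the same arithmetic, with hypothetical memory and block-size figures:

    _GiB = 1 << 30

    def max_num_cpu_blocks(swap_space_gib: int, cpu_memory: int, cache_block_size: int) -> int:
        # Convert the requested swap space from GiB to bytes.
        swap_space = swap_space_gib * _GiB
        # Refuse clearly oversized budgets; warn on merely aggressive ones.
        if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'Swap space ({swap_space / _GiB:.2f} GiB) exceeds 80% '
                             f'of host memory ({cpu_memory / _GiB:.2f} GiB).')
        if swap_space > 0.5 * cpu_memory:
            print(f'WARNING: swap space ({swap_space / _GiB:.2f} GiB) exceeds 50% '
                  f'of host memory ({cpu_memory / _GiB:.2f} GiB); this may slow the system.')
        # The CPU cache holds as many fixed-size KV cache blocks as fit in the budget.
        return swap_space // cache_block_size

    # 20 GiB of swap, 64 GiB of host RAM, 2 MiB per cache block -> 10240 blocks.
    print(max_num_cpu_blocks(20, 64 * _GiB, 2 * (1 << 20)))
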
2 changes: 1 addition & 1 deletion cacheflow/models/sample.py
@@ -11,7 +11,7 @@
class Sampler(nn.Module):

def __init__(self) -> None:
super().__init__()
super(Sampler, self).__init__()

def forward(
self,
7 changes: 7 additions & 0 deletions cacheflow/worker/worker.py
@@ -191,6 +191,13 @@ def execute_stage(
else:
cache_events = None

# If there is no input, we don't need to execute the model.
if not input_seq_groups:
if cache_events is not None:
for event in cache_events:
event.wait()
return {}

# Prepare input tensors.
input_tokens, input_positions, input_metadata = self.prepare_inputs(
input_seq_groups)
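
Note: this is the worker-side counterpart of the scheduler guard. When a stage receives no input sequence groups it returns immediately, but only after waiting on any issued cache events so the swap-in/swap-out/copy operations have completed. A sketch of that control flow with hypothetical types:

    from typing import Dict, List, Optional

    def execute_stage(input_seq_groups: List, cache_events: Optional[List]) -> Dict:
        # With no input sequence groups there is nothing to run, but any
        # in-flight swap-in/swap-out/copy operations must still finish first.
        if not input_seq_groups:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()  # assumes each event exposes a blocking wait()
            return {}
        # ... otherwise prepare input tensors and run the model as before ...
        return {}
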
7 changes: 5 additions & 2 deletions csrc/cache_kernels.cu
@@ -1,5 +1,4 @@
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>

#include <algorithm>
@@ -73,6 +72,8 @@ void copy_blocks(
}
}

namespace cacheflow {

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
@@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
}
}

} // namespace cacheflow

void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
@@ -131,7 +134,7 @@ void reshape_and_cache(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
6 changes: 4 additions & 2 deletions server.py
@@ -15,7 +15,8 @@
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
args = parser.parse_args()


@@ -27,7 +28,8 @@ def main():
)
num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=args.max_batch_size)
num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
swap_space=args.swap_space)
print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

# Create a controller for each node.
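
Note: on the server side, a --swap-space flag (GiB of CPU swap per GPU, default 20) is added and forwarded to get_max_num_cpu_blocks, and the default --max-batch-size is bumped from 2048 to 2560 tokens. A minimal argparse sketch of the new wiring, standalone and without the real memory analyzer:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--swap-space', type=int, default=20,
                        help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-batch-size', type=int, default=2560,
                        help='maximum number of batched tokens')
    args = parser.parse_args([])  # use the defaults for illustration

    # The parsed value is what gets forwarded to the memory analyzer, roughly:
    #   num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(swap_space=args.swap_space)
    print(args.swap_space, args.max_batch_size)  # -> 20 2560
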
