set rope to identity #1

Open
wants to merge 4 commits into base: main
4 changes: 4 additions & 0 deletions include/flashinfer/attention/scheduler.cuh
@@ -230,6 +230,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA(
num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * gdy >= max_grid_size) {
// if (false) {
split_kv = false;
max_num_pages_per_batch = 1;
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
@@ -247,6 +248,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA(
PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages,
std::max(128 / page_size, 1U));
if (new_batch_size == batch_size && !enable_cuda_graph) {
// if (false) {
// do not use partition-kv kernel for short sequence, when not using CUDAGraph
split_kv = false;
} else {
@@ -255,6 +257,8 @@
}
}

printf("split_kv: %d\n", split_kv);

return cudaSuccess;
});
}
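The hunk above touches the decode planner's choice of whether to split a request's KV cache across multiple thread blocks (the commented-out "if (false)" lines and the added printf are debugging aids for inspecting that choice). A rough Python sketch of the decision as it appears in this hunk, assuming the collapsed else branch enables split-KV (that branch is not shown in the diff):

def choose_split_kv(batch_size, gdy, max_grid_size, new_batch_size, enable_cuda_graph):
    # The grid is already full without splitting: keep one KV chunk per request.
    if batch_size * gdy >= max_grid_size:
        return False
    # The binary-search partition did not create extra work units and no CUDA graph
    # is in use: skip the partition-kv kernel for short sequences.
    if new_batch_size == batch_size and not enable_cuda_graph:
        return False
    # Otherwise (collapsed in the diff) split-KV is assumed to be enabled.
    return True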
9 changes: 5 additions & 4 deletions include/flashinfer/pos_enc.cuh
@@ -142,10 +142,11 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_interleav
vec_before = vec;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
float embed = float(offset) * freq[i];
float cos, sin;
__sincosf(embed, &sin, &cos);
vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
// float embed = float(offset) * freq[i];
// float cos, sin;
// __sincosf(embed, &sin, &cos);
// vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
vec[i] = vec[i] * 1.0f; // * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
}
}
return vec;
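For reference, the lines commented out above implement interleaved (LLaMA-style) RoPE: element pairs (2i, 2i+1) are rotated together by the angle offset * freq_i, and the replacement line turns that rotation into the identity. A minimal Python sketch of the disabled math, assuming the usual theta-based frequency schedule (illustration only, not part of the PR):

import torch

def rope_interleave_reference(x: torch.Tensor, offset: int, rope_theta: float = 1e4) -> torch.Tensor:
    # x: [..., head_dim]; the pair (2i, 2i+1) shares the frequency rope_theta ** (-2i / head_dim).
    head_dim = x.shape[-1]
    freq = rope_theta ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angle = offset * freq
    cos, sin = angle.cos(), angle.sin()
    x_even, x_odd = x[..., 0::2].float(), x[..., 1::2].float()
    out = torch.empty_like(x, dtype=torch.float32)
    out[..., 0::2] = x_even * cos - x_odd * sin  # i % 2 == 0: the partner element is negated
    out[..., 1::2] = x_odd * cos + x_even * sin  # i % 2 == 1
    return out.to(x.dtype)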
130 changes: 130 additions & 0 deletions test.py
@@ -0,0 +1,130 @@
import torch
from math import sqrt

torch.random.manual_seed(42)

from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper

FLASHINFER_WORKSPACE_BUFFER_SIZE = 25 * 1024 * 1024

workspace = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device="cuda")


# kv_cache layout: [total_blocks, 2, block_size, num_heads, head_dim]; the RoPE (k_pe) part is stored in the "value" half (index 1)
kv_cache = torch.zeros([2, 2, 16, 1, 512], dtype=torch.bfloat16, device="cuda")
query = torch.zeros([2, 16, 576], dtype=torch.bfloat16, device="cuda")

# here we load the real activations from a sample query using DeepseekV2-Lite-Chat
import safetensors.torch
import sys
if len(sys.argv) > 1:
    path = sys.argv[1]
    print(f"Using weights from {path}")
    state_dict = safetensors.torch.load_file(path)
    q_pe = state_dict["q_pe"]
    q_nope = state_dict["q_nope"]
    k_pe_cache = state_dict["k_pe_cache"]
    compressed_kv_normed_cache = state_dict["compressed_kv_normed_cache"]
    # print(f"{q_pe.shape=}, {q_nope.shape=}, {k_pe_cache.shape=}, {compressed_kv_normed_cache.shape=}")
    # q_pe.shape=torch.Size([2, 16, 64]), q_nope.shape=torch.Size([2, 16, 512]), k_pe_cache.shape=torch.Size([2, 9, 64]), compressed_kv_normed_cache.shape=torch.Size([2, 9, 512])
    query[:, :, :512] = q_nope
    query[:, :, 512:] = q_pe
    kv_cache[:, 0, :9, 0, :] = compressed_kv_normed_cache
    kv_cache[:, 1, :9, 0, :64] = k_pe_cache
    kv_cache[:, 1, :9, 0, 64:] = 0
else:
    print("Randomly initializing")
    kv_cache = torch.randn([2, 2, 16, 1, 512], dtype=torch.bfloat16, device="cuda")
    query = torch.randn([2, 16, 576], dtype=torch.bfloat16, device="cuda")


wrapper = BatchDecodeMlaWithPagedKVCacheWrapper(workspace)

wrapper.plan(
    indptr=torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda"),
    indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
    last_page_len=torch.tensor([7, 9], dtype=torch.int32, device="cuda"),
    num_qo_heads=16,
    head_dim_compressed_kv=512,
    page_size=16,
    sm_scale=1 / sqrt(512 + 64),
    data_type=query.dtype,
)

head_size = 512
q_nope = query[:, :, :head_size].contiguous()
assert q_nope.shape == (2, 16, 512)
q_pe = query[:, :, head_size:].contiguous()
assert q_pe.shape == (2, 16, 64)
paged_ckv_cache = kv_cache[:, 0].squeeze(2).contiguous()
assert paged_ckv_cache.shape == (2, 16, 512)
paged_kpe_cache = kv_cache[:, 1][..., 0, :64].contiguous()
assert paged_kpe_cache.shape == (2, 16, 64)

decode_output = wrapper.run(
    q_nope=q_nope,
    q_pe=q_pe,
    paged_ckv_cache=paged_ckv_cache,
    paged_kpe_cache=paged_kpe_cache,
)

k_pe_cache = torch.zeros(2,
                         9,
                         64,
                         device=kv_cache.device,
                         dtype=kv_cache.dtype)
k_pe_cache[0, :7, :] = kv_cache[0, 1, :7, 0, :64]
k_pe_cache[1, :9, :] = kv_cache[1, 1, :9, 0, :64]

compressed_kv_normed_cache = torch.zeros(2,
                                         9,
                                         512,
                                         device=kv_cache.device,
                                         dtype=kv_cache.dtype)
compressed_kv_normed_cache[0, :7, :] = kv_cache[0, 0, :7, 0, :]
compressed_kv_normed_cache[1, :9, :] = kv_cache[1, 0, :9, 0, :]

# attn_weights_pe ~ [bsz, num_heads, kv_len]
attn_weights_pe = torch.matmul(
    q_pe,  # [bsz, num_heads, qk_rope_head_dim]
    k_pe_cache.transpose(1, 2),  # [bsz, qk_rope_head_dim, kv_len]
)
# attn_weights_nope ~ [bsz, num_heads, kv_len]
attn_weights_nope = torch.matmul(
    q_nope,  # [bsz, num_heads, 512]
    compressed_kv_normed_cache.transpose(1, 2),  # [bsz, 512, kv_len]
)

attn_weights = (attn_weights_pe + attn_weights_nope) * 1 / sqrt(512 + 64)
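
# Sanity check (illustrative, not in the original script): the reference above never
# applies RoPE, so the two partial matmuls are equivalent to a single matmul over the
# concatenated 576-dim features [nope | pe].
q_full = torch.cat([q_nope, q_pe], dim=-1)  # [bsz, num_heads, 576]
kv_full = torch.cat([compressed_kv_normed_cache, k_pe_cache], dim=-1)  # [bsz, kv_len, 576]
attn_weights_alt = torch.matmul(q_full, kv_full.transpose(1, 2)) / sqrt(512 + 64)
# The difference should be on the order of bfloat16 rounding error.
print(f"{(attn_weights_alt - attn_weights).abs().max()=}")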

attn_weights_sm = torch.nn.functional.softmax(
    attn_weights,
    dim=-1,
    dtype=torch.float32,
).to(q_nope.dtype)

# attn_output ~ [bsz, num_heads, 512]
attn_output = torch.matmul(
    attn_weights_sm,  # [bsz, num_heads, kv_len]
    compressed_kv_normed_cache,  # [bsz, kv_len, 512]
)

print(f"{(decode_output - attn_output).abs().sum()=}")

def wmape(target: torch.Tensor, preds: torch.Tensor):
    sum_abs_error = (preds - target).abs().sum().detach().item()
    sum_scale = target.abs().sum().detach().item()
    return sum_abs_error / sum_scale


wmape_value = wmape(decode_output, attn_output)
# assert wmape_value < 0.02, wmape_value
print(f"{wmape_value=}")

cos_similarity = torch.nn.functional.cosine_similarity(
    decode_output.flatten(), attn_output.flatten(), dim=0)
# assert cos_similarity > 0.99, cos_similarity
print(f"{cos_similarity=}")
14 changes: 11 additions & 3 deletions tests/test_mla_decode_kernel.py
@@ -275,17 +275,25 @@ def run_proof_of_concept(
compressed_kv_normed_cache = compressed_kv_normed_cache.to(q_kv_dtype)
k_pe_cache = k_pe_cache.to(q_kv_dtype)

if not use_flashinfer_kernel:
if True:  # True: #not use_flashinfer_kernel:
    freqs_cis = precompute_freqs_cis(
        self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
    ).to(k_pe_cache.device)
    dims_before = (q_pe.shape, k_pe_cache.shape)
    dtype_before = (q_pe.dtype, k_pe_cache.dtype)
    q_pe, k_pe_cache = apply_rotary_emb(
        q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1),
        k_pe_cache.unsqueeze(2),
        freqs_cis,
    )
    q_pe = q_pe[:, -1:, :, :].squeeze(1)
    k_pe_cache = k_pe_cache.squeeze(2)
    q_pe = q_pe[:, -1:, :, :].squeeze(1).contiguous()
    k_pe_cache = k_pe_cache.squeeze(2).contiguous()
    dims_after = (q_pe.shape, k_pe_cache.shape)
    assert dims_before == dims_after
    assert dtype_before == (q_pe.dtype, k_pe_cache.dtype)

if not use_flashinfer_kernel:


    # attn_weights_pe ~ [bsz, 128, kv_len]
    attn_weights_pe = torch.matmul(
Binary file added weights/layer_0.safetensors
Binary file added weights/layer_1.safetensors
Binary file added weights/layer_10.safetensors
Binary file added weights/layer_11.safetensors
Binary file added weights/layer_12.safetensors
Binary file added weights/layer_13.safetensors
Binary file added weights/layer_14.safetensors
Binary file added weights/layer_15.safetensors
Binary file added weights/layer_16.safetensors
Binary file added weights/layer_17.safetensors
Binary file added weights/layer_18.safetensors
Binary file added weights/layer_19.safetensors
Binary file added weights/layer_2.safetensors
Binary file added weights/layer_20.safetensors
Binary file added weights/layer_21.safetensors
Binary file added weights/layer_22.safetensors
Binary file added weights/layer_23.safetensors
Binary file added weights/layer_24.safetensors
Binary file added weights/layer_25.safetensors
Binary file added weights/layer_26.safetensors
Binary file added weights/layer_3.safetensors
Binary file added weights/layer_4.safetensors
Binary file added weights/layer_5.safetensors
Binary file added weights/layer_6.safetensors
Binary file added weights/layer_7.safetensors
Binary file added weights/layer_8.safetensors
Binary file added weights/layer_9.safetensors