
set rope to identity #1

Open · wants to merge 4 commits into main
Conversation

simon-mo (Owner) commented on Jan 6, 2025

tests/test_mla_decode_kernel.py passes

simon-mo (Owner, Author) commented on Jan 8, 2025

This is the output of running test.py after forcing split KV:

split_kv: 1
(decode_output - attn_output).abs().sum()=tensor(1064., device='cuda:0', dtype=torch.bfloat16)
wmape_value=0.19791666666666666
cos_similiarity=tensor(0.9766, device='cuda:0', dtype=torch.bfloat16)
split_kv: 0
(decode_output - attn_output).abs().sum()=tensor(1064., device='cuda:0', dtype=torch.bfloat16)
wmape_value=0.19791666666666666
cos_similiarity=tensor(0.9766, device='cuda:0', dtype=torch.bfloat16)
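
For reference, test.py itself is not shown in this thread; the sketch below is one plausible way the three printed metrics could be computed, assuming wmape is the standard weighted mean absolute percentage error and the cosine similarity is taken over the flattened outputs:

import torch

def compare_outputs(decode_output: torch.Tensor, attn_output: torch.Tensor):
    # Sum of absolute element-wise differences, as printed in the logs above.
    abs_diff_sum = (decode_output - attn_output).abs().sum()
    # Weighted MAPE: total absolute error divided by total absolute reference magnitude.
    wmape_value = ((decode_output - attn_output).abs().sum().float()
                   / attn_output.abs().sum().float()).item()
    # Cosine similarity over the flattened tensors.
    cos_similarity = torch.nn.functional.cosine_similarity(
        decode_output.flatten(), attn_output.flatten(), dim=0)
    return abs_diff_sum, wmape_value, cos_similarity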

cennn commented on Jan 13, 2025

I think I've pinpointed the problem. The distributions of q_pe and q_nope are significantly different from a normal distribution. Although the input hidden_states follows a normal distribution, after multiplying by W_DQ, applying q_a_layernorm, and then multiplying by W_QR, q_pe and q_nope are decidedly not normally distributed.

I ran a simple test in which I simulated q_pe and q_nope as randomly initialized global variables. With the original test code in tests/test_mla_decode_kernel.py, the similarity dropped to approximately 0.96, which is comparable to what we are seeing now.

Conversely, when I tested our code starting from hidden_states instead of q_pe and q_nope (incorporating the W_DQ, q_a_layernorm, and W_QR matrices), the similarity returned to 1. To reproduce my results, simply comment out the initialization of q_pe and q_nope and substitute the following code:

# q_nope = query[:, :, :head_size].contiguous()
# assert q_nope.shape == (2, 16, 512)
# q_pe = query[:, :, head_size:].contiguous()
import torch
from torch import nn
class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

bsz = 2
hidden_size = 5120
q_lora_rank = 1536
num_heads = 16
qk_rope_head_dim = 64
qk_nope_head_dim = 128
hidden_states = torch.randn([bsz, 1, hidden_size], device="cuda")
kv_lora_rank = 512
W_DQ = torch.randn([hidden_size, q_lora_rank], device="cuda")
q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank).to("cuda")
W_QR = torch.randn([q_lora_rank, num_heads * qk_rope_head_dim], device="cuda")
W_UQ = torch.randn([q_lora_rank, num_heads, qk_nope_head_dim], device="cuda")
W_UK = torch.randn([kv_lora_rank, num_heads, qk_nope_head_dim], device="cuda")
# Absorb W_UK into W_UQ: result is [q_lora_rank, num_heads * kv_lora_rank]
W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten(start_dim=1)

c_Q = torch.matmul(hidden_states, W_DQ)  # c_Q ~ [bsz, 1, q_lora_rank]
c_Q = q_a_layernorm(c_Q)
q_pe = torch.matmul(c_Q,    # c_Q ~ [bsz, 1, q_lora_rank ~ 1536]
                    W_QR)   # W_QR ~ [q_lora_rank, num_heads * qk_rope_head_dim]
# q_pe ~ [bsz, num_heads, qk_rope_head_dim]
q_pe = q_pe.reshape(bsz, num_heads, qk_rope_head_dim)
q_nope = torch.matmul(c_Q,
                      W_UQ_UK)  # W_UQ_UK ~ [q_lora_rank, num_heads * kv_lora_rank]
# q_nope ~ [bsz, num_heads, kv_lora_rank]
q_nope = q_nope.reshape(bsz, num_heads, kv_lora_rank)

q_pe = q_pe.to(torch.bfloat16)
q_nope = q_nope.to(torch.bfloat16)

print(f"{q_pe.mean()=}, {q_pe.std()=}")
print(f"{q_nope.mean()=}, {q_nope.std()=}")

The results should look like this:
[screenshot: printed mean and std of q_pe and q_nope]
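
As a quick follow-up check (not part of the snippet above): the original test appears to draw q_pe and q_nope from a roughly unit-variance random query, while the pipeline above derives them through two large matmuls, so their standard deviation is expected to be far from 1. A minimal side-by-side print, reusing the variables defined above and assuming that randn-based baseline:

# Hypothetical baseline: directly randn-initialized q_pe / q_nope (std ~ 1),
# mirroring what a unit-variance random query in the original test would produce.
q_pe_rand = torch.randn(bsz, num_heads, qk_rope_head_dim, device="cuda", dtype=torch.bfloat16)
q_nope_rand = torch.randn(bsz, num_heads, kv_lora_rank, device="cuda", dtype=torch.bfloat16)

print(f"randn init: {q_pe_rand.std()=}, {q_nope_rand.std()=}")
print(f"pipeline:   {q_pe.std()=}, {q_nope.std()=}")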

simon-mo (Owner, Author) commented:
Just pushed a commit with real activations. Usage:

(vllm) simonmo@gcp5-h100-4-7:~/flashinfer$ python test.py
Randomly initializing
2025-01-13 07:31:33,276 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
2025-01-13 07:31:34,536 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
split_kv: 0
(decode_output - attn_output).abs().sum()=tensor(1064., device='cuda:0', dtype=torch.bfloat16)
wmape_value=0.19791666666666666
cos_similiarity=tensor(0.9766, device='cuda:0', dtype=torch.bfloat16)


(vllm) simonmo@gcp5-h100-4-7:~/flashinfer$ python test.py ./weights/layer_16.safetensors
Using weights from ./weights/layer_16.safetensors
2025-01-13 07:32:05,483 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
2025-01-13 07:32:05,772 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
split_kv: 0
(decode_output - attn_output).abs().sum()=tensor(780., device='cuda:0', dtype=torch.bfloat16)
wmape_value=0.22058823529411764
cos_similiarity=tensor(0.9727, device='cuda:0', dtype=torch.bfloat16)


(vllm) simonmo@gcp5-h100-4-7:~/flashinfer$ python test.py ./weights/layer_1.safetensors
Using weights from ./weights/layer_1.safetensors
2025-01-13 07:32:14,085 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
2025-01-13 07:32:14,376 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_512_use_swa_False_use_logits_cap_False
split_kv: 0
(decode_output - attn_output).abs().sum()=tensor(74.5000, device='cuda:0', dtype=torch.bfloat16)
wmape_value=0.14607843137254903
cos_similiarity=tensor(0.9844, device='cuda:0', dtype=torch.bfloat16)
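
test.py is not included in the thread either; based on the invocations above, its optional-weights entry point presumably looks something like the sketch below (the tensor names inside the safetensors file are not shown here, so the loaded dict is left opaque):

import sys
from safetensors.torch import load_file

if len(sys.argv) > 1:
    weights_path = sys.argv[1]
    print(f"Using weights from {weights_path}")
    # load_file returns a dict mapping tensor names to torch.Tensor.
    weights = load_file(weights_path, device="cuda")
else:
    print("Randomly initializing")
    weights = None  # fall back to torch.randn-initialized projections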
