[Bug] flashinfer's RMSNorm implementation causes precision differences in model outputs compared to the HuggingFace implementation #2258
Comments
@BBuf Could you try the nightly version? https://github.com/flashinfer-ai/flashinfer-nightly/releases/tag/0.1.6%2B4ade6f3
Thanks! I'll give it a try.
With flashinfer nightly, I first got this error:

I bypassed that error with the following changes, but then encountered a new one:

[2024-11-29 22:45:06 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=11.39 GB
[2024-11-29 22:45:06 TP1] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=11.39 GB
[2024-11-29 22:45:06 TP1] Memory pool end. avail mem=4.45 GB
[2024-11-29 22:45:06 TP0] Memory pool end. avail mem=4.45 GB
[2024-11-29 22:45:06 TP1] max_total_num_tokens=88831, max_prefill_tokens=16384, max_running_requests=4097, context_len=4096
[2024-11-29 22:45:06 TP0] max_total_num_tokens=88831, max_prefill_tokens=16384, max_running_requests=4097, context_len=4096
[2024-11-29 22:45:06] INFO: Started server process [791965]
[2024-11-29 22:45:06] INFO: Waiting for application startup.
[2024-11-29 22:45:06] INFO: Application startup complete.
[2024-11-29 22:45:06] INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
[2024-11-29 22:45:07] INFO: 127.0.0.1:52246 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-11-29 22:45:07 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-29 22:45:08 TP0] Traceback (most recent call last):
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 1259, in run_scheduler_process
scheduler.event_loop_normal()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 353, in event_loop_normal
result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 812, in run_batch
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/tp_worker.py", line 139, in forward_batch_generation
logits_output = self.model_runner.forward(forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 599, in forward
return self.forward_extend(forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 583, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 318, in forward
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 283, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 233, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 170, in forward
attn_output = self.attn(q, k, v, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/radix_attention.py", line 60, in forward
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/__init__.py", line 60, in forward
return self.forward_extend(q, k, v, layer, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 236, in forward_extend
o = prefill_wrapper_paged.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1144, in forward
return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1211, in run
_check_cached_qkv_data_type(
File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 199, in _check_cached_qkv_data_type
raise ValueError(
ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
[2024-11-29 22:45:08 TP1] Traceback (most recent call last):
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 1259, in run_scheduler_process
scheduler.event_loop_normal()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 353, in event_loop_normal
result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 812, in run_batch
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/tp_worker.py", line 139, in forward_batch_generation
logits_output = self.model_runner.forward(forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 599, in forward
return self.forward_extend(forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 583, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 318, in forward
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 283, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 233, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 170, in forward
attn_output = self.attn(q, k, v, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/radix_attention.py", line 60, in forward
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/__init__.py", line 60, in forward
return self.forward_extend(q, k, v, layer, forward_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 236, in forward_extend
o = prefill_wrapper_paged.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1144, in forward
return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1211, in run
_check_cached_qkv_data_type(
File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 199, in _check_cached_qkv_data_type
raise ValueError(
ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.
Killed
It has been fixed by #2179
As a workaround for the dtype inconsistency, the dtype can be forcibly specified as fp16.
OK, I'll try it later.
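To illustrate the mismatch behind the ValueError above, here is a hedged sketch, not the actual SGLang call site: the dtype planned for q has to match the dtype of the q tensor later passed to run(). Everything here except the q_data_type keyword (which appears in the error message) is a placeholder.

```python
# Hedged sketch of the dtype mismatch: the dtype planned for q must match the
# dtype of the q tensor later passed to run(). `wrapper` stands for a flashinfer
# BatchPrefillWithPagedKVCacheWrapper, and `plan_args`/`plan_kwargs` are
# placeholders for the real layout arguments; only the q_data_type keyword is
# taken from the error message itself.
import torch

def plan_for_query(wrapper, plan_args: tuple, plan_kwargs: dict, q: torch.Tensor) -> None:
    # Either plan with the actual query dtype (bfloat16 here), or cast q to the
    # planned dtype before calling run() -- the "force fp16" workaround above.
    wrapper.plan(*plan_args, q_data_type=q.dtype, **plan_kwargs)
```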
@BBuf there was a numerical issue in
That is because flashinfer recently introduced JIT compilation, which does not compile kernels at pip install time. If you want to compile kernels ahead of time, check the AOT mode. By the way, it's better to ask flashinfer-related questions in the flashinfer issue tracker; otherwise I might miss them.
I have confirmed that the fix above resolves our precision issue when deploying the llama3 model with tp2. Thank you all @zhyncs @yzh119 @zhaochenyang20 @merrymercy
Hey, I want to know: if we deploy llama3 or qwen2 with tp1, will this issue still be reproduced?
We did not observe this issue with qwen2; it might be related to the way we pretrained llama3.
Thank you very much for your work. By the way, how can we avoid this problem when using sglang==0.4.0 + flashinfer==0.1.6? Or do we have to use flashinfer-nightly, as you mentioned in your Zhihu post?
You can install flashinfer-nightly, but you need to verify whether your model outputs are affected by this RMSNorm precision issue.
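For example, a quick check along these lines; this is a hedged sketch that assumes flashinfer exposes flashinfer.norm.rmsnorm(input, weight, eps) as in recent releases, and the printed difference should only be judged against an illustrative bf16-scale bound, not a hard threshold.

```python
# Hedged sketch: compare flashinfer's fused RMSNorm kernel against an
# HF-LlamaRMSNorm-style reference on random bf16 activations.
# Assumption: flashinfer.norm.rmsnorm(input, weight, eps) is available in the
# installed version.
import torch
from flashinfer.norm import rmsnorm

def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in float32, then apply the weight in the original dtype,
    # matching the HuggingFace LlamaRMSNorm behavior.
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x32.to(x.dtype)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.rand(4096, dtype=torch.bfloat16, device="cuda")

out_fused = rmsnorm(x, w, eps=1e-6)
out_ref = reference_rmsnorm(x, w, eps=1e-6)

# Worst-case elementwise difference between the fused kernel and the reference.
max_diff = (out_fused.float() - out_ref.float()).abs().max().item()
print(f"max abs diff: {max_diff:.6f}")
```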
Checklist
Describe the bug
We are serving a custom bf16 llama3 8b model with tp2 on RTX 4090 GPUs with SGLang. We discovered that using flashinfer's RMSNorm implementation resulted in low-quality outputs, while switching to a naive RMSNorm implementation (sketched below) avoided them, so there may be a precision issue in flashinfer's RMSNorm implementation. We are using the official flashinfer version 0.1.6.
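For reference, a minimal sketch of what such a naive, HF-LlamaRMSNorm-style implementation typically looks like; this is a generic example, not the exact module used in our deployment.

```python
# Minimal sketch of a naive RMSNorm module in the HF LlamaRMSNorm style
# (generic example, not the exact code used in our deployment).
import torch
from torch import nn

class NaiveRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then apply the weight in the
        # original dtype (bf16 in our setup).
        x32 = x.float()
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x32.to(x.dtype)
```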
Reproduction
Custom bf16 llama3 8b model with tp2 on RTX 4090 GPUs.
Environment