
[Bug] flashinfer's RMSNorm implementation causes precision differences in model outputs compared to the HuggingFace implementation #2258

Closed
BBuf opened this issue Nov 29, 2024 · 14 comments
Labels: bug (Something isn't working)

BBuf commented Nov 29, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

We are serving a custom bf16 llama3 8b model with tp2 on RTX 4090 GPUs with SGLang. We discovered that using flashinfer's RMSNorm implementation resulted in low-quality outputs, while switching to a naive RMSNorm implementation avoided them. Therefore, there might be a precision issue in flashinfer's RMSNorm implementation. We are using the official flashinfer version 0.1.6.
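To make the precision concern concrete, here is a minimal NumPy sketch (an illustration, not flashinfer's kernel) contrasting an RMSNorm that accumulates the mean-square in float32, as the HuggingFace reference does, with a hypothetical variant that rounds through float16; the half-precision round-trip alone already perturbs the output:

```python
import numpy as np

def rmsnorm_fp32(x, weight, eps=1e-6):
    # Reference-style RMSNorm: accumulate the mean-square in float32
    # (as the HuggingFace implementation does), then cast back.
    x32 = x.astype(np.float32)
    variance = np.mean(x32 * x32, axis=-1, keepdims=True)
    return (x32 / np.sqrt(variance + eps) * weight.astype(np.float32)).astype(x.dtype)

def rmsnorm_halfprec(x, weight, eps=1e-6):
    # Hypothetical low-precision variant: round activations to float16
    # before normalizing, illustrating where drift can creep in.
    x16 = x.astype(np.float16)
    variance = np.mean(x16 * x16, axis=-1, keepdims=True)
    return (x16 / np.sqrt(variance + eps) * weight.astype(np.float16)).astype(x.dtype)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4096)).astype(np.float32)
w = np.ones(4096, dtype=np.float32)
ref = rmsnorm_fp32(x, w)
low = rmsnorm_halfprec(x, w)
# The two paths diverge by a small but nonzero amount per call; in a deep
# network these per-layer differences can compound into visible quality loss.
max_diff = float(np.max(np.abs(ref - low)))
```

A single layer's drift is tiny, but llama3 8b applies RMSNorm twice per decoder layer, so small kernel-level discrepancies can accumulate.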

Reproduction

Serve a custom bf16 llama3 8b model with tp2 on RTX 4090 GPUs.

Environment

Python: 3.12.7 (main, Oct  1 2024, 08:52:12) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA GeForce RTX 4090
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.9
CUDA_HOME: /usr/local/cuda
NVCC: Not Available
CUDA Driver Version: 535.154.05
PyTorch: 2.4.0+cu121
sglang: 0.3.5.post2
flashinfer: 0.1.6
triton: 3.0.0
transformers: 4.46.2
requests: 2.32.3
tqdm: 4.67.0
numpy: 1.26.4
aiohttp: 3.11.2
fastapi: 0.115.5
hf_transfer: 0.1.8
huggingface_hub: 0.26.2
interegular: 0.3.3
packaging: 24.2
PIL: 10.4.0
psutil: 6.1.0
pydantic: 2.9.2
uvicorn: 0.32.0
uvloop: 0.21.0
zmq: 26.2.0
vllm: 0.6.3.post1
multipart: 0.0.17
openai: 1.54.4
anthropic: 0.39.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PIX     SYS     SYS     SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU1    PIX      X      SYS     SYS     SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU2    SYS     SYS      X      PIX     SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU3    SYS     SYS     PIX      X      SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU4    SYS     SYS     SYS     SYS      X      PIX     SYS     SYS     32-63,96-127    1               N/A
GPU5    SYS     SYS     SYS     SYS     PIX      X      SYS     SYS     32-63,96-127    1               N/A
GPU6    SYS     SYS     SYS     SYS     SYS     SYS      X      PIX     32-63,96-127    1               N/A
GPU7    SYS     SYS     SYS     SYS     SYS     SYS     PIX      X      32-63,96-127    1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 655350
BBuf commented Nov 29, 2024

How can I compile flashinfer from source? I tried this:

[screenshot]

But there was no C++ compile process, and it directly installed flashinfer v0.1.6:

[screenshot]

zhyncs commented Nov 29, 2024

@BBuf May you try the nightly version https://github.com/flashinfer-ai/flashinfer-nightly/releases/tag/0.1.6%2B4ade6f3

zhyncs self-assigned this Nov 29, 2024
zhyncs added the bug (Something isn't working) label Nov 29, 2024
BBuf commented Nov 30, 2024

> @BBuf May you try the nightly version https://github.com/flashinfer-ai/flashinfer-nightly/releases/tag/0.1.6%2B4ade6f3

Thanks! I'll have a try.

BBuf commented Nov 30, 2024

@zhyncs

With the flashinfer nightly, I first get this error:

[screenshot]

I bypassed this error with the following changes, but then encountered a new error:

[screenshot]

[2024-11-29 22:45:06 TP0] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=11.39 GB
[2024-11-29 22:45:06 TP1] Load weight end. type=MistralForCausalLM, dtype=torch.bfloat16, avail mem=11.39 GB
[2024-11-29 22:45:06 TP1] Memory pool end. avail mem=4.45 GB
[2024-11-29 22:45:06 TP0] Memory pool end. avail mem=4.45 GB
[2024-11-29 22:45:06 TP1] max_total_num_tokens=88831, max_prefill_tokens=16384, max_running_requests=4097, context_len=4096
[2024-11-29 22:45:06 TP0] max_total_num_tokens=88831, max_prefill_tokens=16384, max_running_requests=4097, context_len=4096
[2024-11-29 22:45:06] INFO:     Started server process [791965]
[2024-11-29 22:45:06] INFO:     Waiting for application startup.
[2024-11-29 22:45:06] INFO:     Application startup complete.
[2024-11-29 22:45:06] INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
[2024-11-29 22:45:07] INFO:     127.0.0.1:52246 - "GET /get_model_info HTTP/1.1" 200 OK
[2024-11-29 22:45:07 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2024-11-29 22:45:08 TP0] Traceback (most recent call last):
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 1259, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 353, in event_loop_normal
    result = self.run_batch(batch)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/scheduler.py", line 812, in run_batch
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/managers/tp_worker.py", line 139, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 599, in forward
    return self.forward_extend(forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/model_executor/model_runner.py", line 583, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 318, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 283, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 233, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/models/llama.py", line 170, in forward
    attn_output = self.attn(q, k, v, forward_batch)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/radix_attention.py", line 60, in forward
    return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/__init__.py", line 60, in forward
    return self.forward_extend(q, k, v, layer, forward_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/bbuf/upstream-vllm/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 236, in forward_extend
    o = prefill_wrapper_paged.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1144, in forward
    return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/prefill.py", line 1211, in run
    _check_cached_qkv_data_type(
  File "/usr/local/lib/python3.12/dist-packages/flashinfer/utils.py", line 199, in _check_cached_qkv_data_type
    raise ValueError(
ValueError: The dtype of q torch.bfloat16 does not match the q_data_type torch.float16 specified in plan function.

[2024-11-29 22:45:08 TP1] reported an identical traceback, ending with the same ValueError.

Killed

zhyncs commented Dec 1, 2024

It has been fixed with #2179

zhyncs commented Dec 1, 2024

The dtype inconsistency issue can be worked around by forcibly specifying the dtype as fp16.
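For context, the ValueError above comes from a consistency check in flashinfer's run path: the dtype of q at call time must match the q_data_type given to plan(). A minimal pure-Python sketch of that check (function and parameter names here are illustrative, not flashinfer's exact internals); the workaround is to keep the two in agreement, e.g. by pinning both sides to fp16:

```python
def check_q_dtype(q_dtype: str, planned_q_dtype: str) -> None:
    """Sketch of the consistency check behind the ValueError above.

    flashinfer's plan() caches an expected q_data_type; run() then refuses
    tensors whose dtype differs.  Names here are illustrative.
    """
    if q_dtype != planned_q_dtype:
        raise ValueError(
            f"The dtype of q {q_dtype} does not match the q_data_type "
            f"{planned_q_dtype} specified in plan function."
        )

# Matching dtypes pass silently; a bf16 q against an fp16 plan reproduces
# the error reported in this issue.
check_q_dtype("torch.float16", "torch.float16")
```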

BBuf commented Dec 1, 2024

> The dtype inconsistency issue can be worked around by forcibly specifying the dtype as fp16.

Ok, I'll have a try later.

yzh119 commented Dec 1, 2024

@BBuf there was a numerical issue in the fused_add_rmsnorm function (the rmsnorm function in flashinfer is okay), which has been fixed in flashinfer-ai/flashinfer#587.

> But there was no C++ compile process, and it directly installed flashinfer v0.1.6

That is because flashinfer recently introduced JIT compilation, which does not compile kernels at pip install time. If you want to compile kernels ahead of time, check the AOT mode.

btw, it's better to ask flashinfer-related questions in the flashinfer issue tracker; otherwise I might miss them.
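For readers unfamiliar with the op, fused_add_rmsnorm fuses the residual add with the normalization. A minimal NumPy sketch of its intended semantics (an illustration with float32 accumulation, not the fixed CUDA kernel):

```python
import numpy as np

def fused_add_rmsnorm(hidden, residual, weight, eps=1e-6):
    # Intended semantics: add the residual, RMS-normalize the sum, and
    # return both the normalized output and the updated residual.
    # Accumulating in float32 is what keeps bf16 inputs well-behaved.
    acc = residual.astype(np.float32) + hidden.astype(np.float32)
    variance = np.mean(acc * acc, axis=-1, keepdims=True)
    normed = acc / np.sqrt(variance + eps) * weight.astype(np.float32)
    return normed.astype(hidden.dtype), acc.astype(hidden.dtype)

rng = np.random.default_rng(1)
h = rng.standard_normal((2, 8)).astype(np.float32)
r = np.zeros_like(h)
w = np.ones(8, dtype=np.float32)
out, new_residual = fused_add_rmsnorm(h, r, w)
```

With a zero residual this reduces to a plain RMSNorm of `hidden`, which makes it easy to cross-check a fused kernel against an unfused reference.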

BBuf commented Dec 1, 2024

> @BBuf there was a numerical issue in the fused_add_rmsnorm function (the rmsnorm function in flashinfer is okay), which has been fixed in flashinfer-ai/flashinfer#587.
>
> > But there was no C++ compile process, and it directly installed flashinfer v0.1.6
>
> That is because flashinfer recently introduced JIT compilation, which does not compile kernels at pip install time. If you want to compile kernels ahead of time, check the AOT mode.
>
> btw, it's better to ask flashinfer-related questions in the flashinfer issue tracker; otherwise I might miss them.

I have confirmed that the fix above resolves our precision issue when deploying the llama3 model with tp2. Thank you all @zhyncs @yzh119 @zhaochenyang20 @merrymercy

BBuf closed this as completed Dec 1, 2024
zzh-www commented Dec 6, 2024

Hey, I want to know: if we deploy llama3 or qwen2 with tp1, will this issue be reproduced?

BBuf commented Dec 6, 2024

> Hey, I want to know: if we deploy llama3 or qwen2 with tp1, will this issue be reproduced?

We did not observe this situation with qwen2, which might be related to the way we pretrained llama3.

rbao2018 commented Dec 9, 2024

> > Hey, I want to know: if we deploy llama3 or qwen2 with tp1, will this issue be reproduced?
>
> We did not observe this situation with qwen2, which might be related to the way we pretrained llama3.

Thank you very much for your work.

BTW, how can we avoid this problem when using sglang==0.4.0 + flashinfer==0.1.6? Or do we have to use flashinfer-nightly as you mentioned in your Zhihu post?

BBuf commented Dec 13, 2024

> > > Hey, I want to know: if we deploy llama3 or qwen2 with tp1, will this issue be reproduced?
> >
> > We did not observe this situation with qwen2, which might be related to the way we pretrained llama3.
>
> Thank you very much for your work.
>
> BTW, how can we avoid this problem when using sglang==0.4.0 + flashinfer==0.1.6? Or do we have to use flashinfer-nightly as you mentioned in your Zhihu post?

You can install flashinfer-nightly, but you need to verify whether your model outputs are affected by this RMSNorm precision issue.
