Qwen2vl vision encoder fix #2365

Draft: wants to merge 5 commits into base branch main

@jakep-allenai jakep-allenai commented Dec 5, 2024

Potential fix for #2112

Motivation

Users have reported worse model performance (output quality) when running Qwen2-VL in sglang and vllm than with HF transformers. I identified a few places where the vision encoder computes things differently from the transformers implementation. With these changes, the outputs should match.

This is currently a draft PR because the change roughly halves inference speed.

Modifications

  • Reverted to exactly the same QuickGELU implementation as HF transformers.
  • Fixed dtype casting issues in the attention module used by the vision network.
  • Switched back to PyTorch SDPA (F.scaled_dot_product_attention).

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@merrymercy merrymercy left a comment

Thanks for investigating this!

@@ -30,10 +30,12 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.config import CacheConfig, MultiModalConfig
Contributor

remove unused imports

from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Contributor

remove unused imports

q = q.squeeze(0)
k = k.squeeze(0)
v = v.squeeze(0)
output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
Contributor

I think PyTorch SDPA should also be fast; the problem is probably how you prepare the attention mask.
Can you vectorize the code more, use fewer Python for-loops, write a Triton kernel for it (see example), or cache the results so we can reuse them across layers?
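
For illustration only (this is not code from the PR): a minimal sketch of how a block-diagonal attention mask could be built without a Python loop and then reused across layers. It assumes the segment boundaries arrive as a cumulative-length tensor such as `[0, 3, 5]`; the names `block_diagonal_mask` and `attend_all_layers` are hypothetical.

```python
import torch
import torch.nn.functional as F


def block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Return a [seq, seq] boolean mask, True where query and key share a segment.

    cu_seqlens holds cumulative segment boundaries, e.g. [0, 3, 5] for two
    segments of lengths 3 and 2 (an assumed layout, for illustration).
    """
    seq_len = int(cu_seqlens[-1])
    positions = torch.arange(seq_len, device=cu_seqlens.device)
    boundaries = cu_seqlens[1:].to(positions.dtype).contiguous()
    # With right=True, position p gets the index i where
    # boundaries[i - 1] <= p < boundaries[i], i.e. its segment id.
    segment_ids = torch.bucketize(positions, boundaries, right=True)
    return segment_ids[:, None] == segment_ids[None, :]


def attend_all_layers(qkv_per_layer, cu_seqlens):
    """Build the mask once and reuse it for every layer (hypothetical helper)."""
    mask = block_diagonal_mask(cu_seqlens)
    outputs = []
    for q, k, v in qkv_per_layer:  # each of shape [batch, heads, seq, head_dim]
        outputs.append(
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        )
    return outputs
```

With a boolean `attn_mask`, `True` marks positions that take part in attention, so tokens only attend within their own segment. Whether mask construction is actually the bottleneck here would need profiling.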

Contributor Author

Hmm, I have tried caching the attention mask, but it doesn't seem to impact performance much.

The issue I see is that for a single layer, the context_attention_fwd kernel in sglang matches torch's scaled_dot_product_attention pretty closely, within 1e-2 for each activation. But Qwen2-VL has 32 layers, and the difference accumulates across them, reaching a max absolute difference closer to ±1.0 in the activations.
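
For illustration only: a minimal sketch of one way to look at per-layer drift rather than only the final output. It records the max absolute difference after every layer for two runs of the same stack. The toy stack (Linear + Tanh, float32 against a float64 copy) is an assumption standing in for the real vision encoder and the two attention backends; it does not reproduce the numbers quoted above.

```python
import copy

import torch
import torch.nn as nn


def capture_layer_outputs(layers, x):
    """Run x through the layers in order, returning each layer's output as float64."""
    outputs = []
    with torch.no_grad():
        for layer in layers:
            x = layer(x)
            outputs.append(x.detach().to(torch.float64))
    return outputs


def layerwise_max_abs_diff(reference, candidate):
    """Max absolute difference between matching layer outputs."""
    return [(r - c).abs().max().item() for r, c in zip(reference, candidate)]


def build_toy_stack(num_layers=32, dim=64, seed=0):
    """A toy stand-in for an encoder: num_layers blocks of Linear + Tanh."""
    torch.manual_seed(seed)
    return nn.ModuleList(
        [nn.Sequential(nn.Linear(dim, dim), nn.Tanh()) for _ in range(num_layers)]
    )


if __name__ == "__main__":
    layers32 = build_toy_stack()
    layers64 = copy.deepcopy(layers32).double()  # higher-precision reference
    x = torch.randn(8, 64)
    ref = capture_layer_outputs(layers64, x.double())
    cand = capture_layer_outputs(layers32, x)
    for i, diff in enumerate(layerwise_max_abs_diff(ref, cand), start=1):
        print(f"layer {i:2d}: max abs diff = {diff:.3e}")
```

Applied to the real model, the same comparison could be fed with activations captured from the two attention implementations, which would show at which layer the gap starts to widen.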

"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * torch.sigmoid(1.702 * input)
Contributor

try torch.compile to fuse them?
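
For illustration only: a minimal sketch of what the `torch.compile` suggestion could look like for the QuickGELU formula above. The helper names are hypothetical, and whether the compiled version is actually faster, or numerically identical to eager mode, is an assumption that would need to be measured.

```python
import torch


def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the snippet under review: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)


def make_fused_quick_gelu():
    """Return a compiled version; compilation happens lazily on the first call."""
    return torch.compile(quick_gelu)


def check_matches_eager(fn, x, rtol=1e-5, atol=1e-6):
    """Raise if fn(x) deviates from the eager reference beyond the tolerances."""
    torch.testing.assert_close(fn(x), quick_gelu(x), rtol=rtol, atol=atol)
```

Since this PR is about matching the transformers outputs, a check like `check_matches_eager(make_fused_quick_gelu(), x)` on representative inputs seems worth running before adopting a fused version.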

QWEN2_VL_MODEL = "Qwen/Qwen2-VL-7B-Instruct"


class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
Contributor

This looks good. Maybe give it a better name.

Now we have a very good reference script for text-only models and a very good model support guide:
https://github.com/sgl-project/sglang/blob/main/scripts/playground/reference_hf.py
https://sgl-project.github.io/references/supported_models.html#how-to-support-a-new-model

Would you be willing to help add similar scripts and docs for vision language models?

@merrymercy
Contributor

@jakep-allenai Do you have any updates on this? Qwen2vl is a very popular model so we would like to fix it soon.

@jakep-allenai
Contributor Author

No. My implementation with F.scaled_dot_product_attention still ran at half the speed after caching the attention mask, and I never heard back from @Mr-Loevan about rerunning his benchmark to see whether it fixes his reported issue. On our side, we found no significant difference in user preference between generations from vllm (which used the xformers backend) and from sglang.

My current theory is that the memory-efficient attention implementation in sglang is accurate enough for a single layer, but that small errors accumulate over a typical full network of 30+ layers.
