Any plans to support tree attention mask? #924
Comments
Sure, we'll just need someone to contribute :D
I'm keen to try supporting a generic mask case, like [B, Q, K] bool, and doing conditional execution. Ideally this covers quite a lot of masking cases, but I guess optimised kernels would work better for more structured masks (like Tree).
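To make the "generic bool mask plus conditional execution" idea concrete, here is a minimal pure-Python sketch for a single head and a single batch element. The function name `masked_attention`, the True-means-attend convention, and the choice to emit zeros for a fully masked row are all assumptions made for illustration; this is not flash-attn's API or kernel logic.

```python
import math


def masked_attention(q, k, v, mask):
    """Single-head attention for one batch element (illustrative only).

    q: [Q][D] queries, k: [K][D] keys, v: [K][Dv] values,
    mask: [Q][K] bools, where True means "query i may attend to key j"
    (an assumed convention).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_i in enumerate(q):
        # Conditional execution: only score the keys the mask allows.
        allowed = [j for j in range(len(k)) if mask[i][j]]
        if not allowed:
            # Fully masked row: emit zeros (a choice made here, not a standard).
            out.append([0.0] * len(v[0]))
            continue
        scores = [scale * sum(a * b for a, b in zip(q_i, k[j])) for j in allowed]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][d] for w, j in zip(weights, allowed)) / total
            for d in range(len(v[0]))
        ])
    return out
```

A real kernel would skip whole tiles rather than single entries, which is where structured masks can be cheaper than an arbitrary bool tensor.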
I don't see much difference between a generic mask and a structured mask. For a tree mask, the mask argument would also have shape [B, K, Q]. The 4D attention mask mentioned above is nothing but [b, h, k, q], with h being the number of heads. If you are able to implement a generic mask, then a structured mask will be ready as well.
What I mean is that for a structured mask you don't necessarily have to create a bool tensor. In the causal case it can be hardcoded in the kernel to ignore j>i+k_cache, which saves a little bit of memory. If it's structured, the locations you'll visit are predictable.
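As a rough sketch of the difference, the causal rule can be evaluated from indices alone, while a generic mask has to be stored. Here i is taken to be the query index, j the key index, and k_cache the number of keys already in the cache; that reading of the symbols, and both function names, are assumptions for illustration.

```python
def causal_allowed(i, j, k_cache):
    """Rule evaluated on the fly: query i may see key j unless j > i + k_cache."""
    return j <= i + k_cache


def materialize_causal_mask(num_queries, k_cache):
    """The same rule stored as a [Q][K] bool mask, with K = k_cache + num_queries."""
    num_keys = k_cache + num_queries
    return [
        [causal_allowed(i, j, k_cache) for j in range(num_keys)]
        for i in range(num_queries)
    ]
```

The first function needs no memory beyond the indices; the second allocates Q x K booleans that carry exactly the same information.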
I see. Yeah, in the causal mask case, the bool tensor mask argument is indeed not required. For the tree attention mask, however, this argument will be unavoidable. But this probably doesn't add much implementation complexity, since the causal mask will internally be converted to such a tensor anyway. @thorinf Looking forward to your PR!
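For readers unfamiliar with why a tree mask needs an explicit tensor, here is a minimal sketch of how one could be built for speculative decoding. It assumes the draft tokens are described by a list of parent indices (-1 for a root, no cycles), that every draft token sees the whole cached prefix, and that True means "attend". The function name `tree_attention_mask` and its parameters are hypothetical and are not part of flash-attn or transformers.

```python
def tree_attention_mask(parents, cache_len=0):
    """Build a [Q][cache_len + Q] bool mask for a token tree.

    parents[i] is the index of token i's parent among the draft tokens,
    or -1 if token i is a root. Each token attends to the cached prefix,
    to itself, and to its ancestors in the tree.
    """
    n = len(parents)
    mask = []
    for i in range(n):
        row = [True] * cache_len + [False] * n
        node = i
        while node != -1:  # walk up to the root, marking self and ancestors
            row[cache_len + node] = True
            node = parents[node]
        mask.append(row)
    return mask
```

With `parents = [-1, 0, 0, 1]`, tokens 1 and 2 are sibling continuations of token 0, so neither sees the other. A plain chain such as `[-1, 0, 1]` reduces to the usual lower-triangular causal mask, which is why the tree case cannot be expressed by an index rule alone.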
Hello, sorry for the naive question but:
It might help me understand this a bit more :)
Please check out this blog post describing 4D masks.
Tree attention mask is already supported in huggingface/transformers: huggingface/transformers#27539
It will be very helpful for speculative decoding applications. More specifically, in flash_attn/flash_attn_interface.py#flash_attn_with_kvcache, the tree attention mask will need to be specified and passed in as an argument. Do you have any near-term plans to support it?
Thanks
Related questions: #840, #918