Mseznec/flash attention fp8 #14570
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Force-pushed from bc909f9 to 2b985ed.
CI is failing. @robertgshaw2-redhat, any idea how I should fix it? Just rename it in run-tpu-test.sh? (@NickLucche, you moved the file.)
This is a known issue; there's a PR addressing it here: #13898. It won't block your PR.
I see there's some other problem with building the image, but CI likely just needs another spin.
@mickaelseznec apologies for the delay. vllm-project/flash-attention#50 (review) has been merged, so you can now point to vllm_flash_attn. We will need to populate the sccache on the server to get it through the CI; I can help with this once the tag is updated 👍
Thanks for the contribution! Looks clean 😄. I'll approve once we can get it updated to use vllm_flash_attn; added a couple of comments.
q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))
nit: could we maybe test per-head scales here too, i.e. also test with non-zero strides?
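A hedged sketch of what that per-head variant might look like in the test; `num_seqs` and `num_kv_heads` come from the snippet above, while the concrete values and dtype are assumptions for illustration only:

```python
# Hypothetical extension of the test above (not part of the PR): use genuinely
# per-head descale factors so the tensors have non-zero strides, instead of
# expanding a single scalar scale.
import torch

num_seqs, num_kv_heads = 4, 8  # example values, matching the names used above

q_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)
k_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)
v_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)

# Contiguous per-head tensors have strides (num_kv_heads, 1), not the (0, 0)
# produced by expand()-ing a scalar, so the per-head code path is exercised.
assert q_descale.stride() == (num_kv_heads, 1)
```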
@@ -240,15 +240,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
    "Cannot use FlashAttention-2 backend for dtype other than "
    "torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and \
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should keep this check but restrict it to FA2, i.e. check get_flash_attn_version() != 2 (get_flash_attn_version() is in vllm/attention/backends/utils.py).
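One possible reading of that suggestion, as a hedged sketch rather than the PR's actual code: only fall back away from FlashAttention when FA2 is the resolved version, since FA3 supports FP8. The helper name is hypothetical, and the "fp8" prefix check mirrors the removed code in the diff context above.

```python
# Sketch only: keep the FP8-KV-cache fallback, but restrict it to FA2
# (equivalent to skipping it whenever get_flash_attn_version() != 2).
from typing import Optional

from vllm.attention.backends.utils import get_flash_attn_version


def needs_fp8_kv_cache_fallback(kv_cache_dtype: Optional[str]) -> bool:
    """Hypothetical guard: only FA2 lacks FP8 KV-cache support, so only then
    should the backend fall back (e.g. to XFORMERS)."""
    return (kv_cache_dtype is not None
            and kv_cache_dtype.startswith("fp8")
            and get_flash_attn_version() == 2)
```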
This PR adds support for FP8 KV cache with FlashAttention 3 (related PR in flash-attn here). cc @LucasWilkinson. Please do not merge this PR as long as it's not referencing vllm-project/flash-attention yet.
FlashAttention (contrary to FlashInfer) does attention with Q, K and V all in FP8. The performance is usually better than both FlashInfer with FP8 KV cache and FlashAttention 3 with bf16.
I added support for v0 and v1, plus some unit tests.
Note that I've added a trick for checkpoints that don't provide q_scale: the k_scale is reused (which is something TRTLLM does, fwiw).
Also: I added a small QoL improvement when debugging v1: workers now send back their traceback when they raise an exception.
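A minimal illustrative sketch of the q_scale fallback described above; the helper name and its usage are hypothetical, not the PR's actual implementation.

```python
# Illustrative only: when a checkpoint does not ship a q_scale, reuse k_scale
# as the query descale factor (mirroring the TRTLLM behaviour noted above).
from typing import Optional

import torch


def resolve_q_descale(q_scale: Optional[torch.Tensor],
                      k_scale: torch.Tensor) -> torch.Tensor:
    """Return the query descale factor, falling back to k_scale if absent."""
    return q_scale if q_scale is not None else k_scale


# Checkpoint without q_scale: the k_scale is reused as-is.
k_scale = torch.tensor(1.0 / 448.0)
assert resolve_q_descale(None, k_scale) is k_scale
```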