
Mseznec/flash attention fp8 #14570

Open

wants to merge 5 commits into base: main from mseznec/flash-attention-fp8

Conversation

@mickaelseznec commented Mar 10, 2025

This PR adds support for FP8 KV cache with FlashAttention 3 (related PR in flash-attn here). cc @LucasWilkinson. Please do not merge this PR until it references vllm-project/flash-attention.

FlashAttention (unlike FlashInfer) performs the attention with Q, K, and V all in FP8.
Performance is usually better than both FlashInfer with an FP8 KV cache and FlashAttention 3 with bf16.
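
For context, here is a minimal conceptual sketch (not vLLM's or flash-attn's actual code) of per-tensor FP8 quantization; the resulting scale is the quantity that the q_scale/k_scale/v_scale parameters correspond to and that the kernel applies as a descale factor:

```python
# Conceptual sketch only, not the PR's implementation: per-tensor FP8
# quantization. The scale maps the tensor into float8_e4m3fn range; the kernel
# later multiplies by the same scale ("descale") inside the attention math.
import torch

def quantize_fp8(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_fp8, scale  # scale is reused as the descale factor at kernel time
```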

I added support for v0 and v1, plus some unit tests.

Note that I've added a trick for checkpoints that don't provide q_scale: the k_scale is reused instead (which is something TRTLLM does, fwiw).
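
For illustration, a rough sketch of that fallback (names are illustrative, not necessarily the PR's exact fields):

```python
# Illustrative sketch of the checkpoint fallback, not the PR's exact code:
# when a checkpoint ships k_scale/v_scale but no q_scale, reuse k_scale as the
# query descale factor (mirroring TRTLLM's behavior).
from typing import Optional
import torch

def resolve_q_scale(q_scale: Optional[torch.Tensor],
                    k_scale: torch.Tensor) -> torch.Tensor:
    if q_scale is None:
        # No query scale in the checkpoint: fall back to the key scale.
        return k_scale
    return q_scale
```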

Also, I added a small quality-of-life improvement for debugging v1: workers now send back their traceback when they raise an exception.
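
One way such worker-side error reporting could look (a hedged sketch, not vLLM's actual v1 executor code):

```python
# Hedged sketch, not vLLM's actual v1 worker loop: catch exceptions in the
# worker process and ship the formatted traceback back to the engine, so the
# driver log shows where the worker actually failed, not just the exception repr.
import traceback
from multiprocessing import Queue

def worker_step(fn, response_queue: Queue):
    try:
        response_queue.put(("ok", fn()))
    except Exception as exc:
        response_queue.put(("error", f"{exc!r}\n{traceback.format_exc()}"))
```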


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run the other CI tests on top of these by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Mickael Seznec <mickael@mistral.ai> (all commits)

@mickaelseznec force-pushed the mseznec/flash-attention-fp8 branch from bc909f9 to 2b985ed on March 10, 2025 at 15:41
@mickaelseznec (Author)

CI is failing because vllm/tests/entrypoints/openai/test_accuracy.py, referenced here, no longer exists.

@robertgshaw2-redhat any idea how I should fix this? Just rename it in run-tpu-test.sh? (@NickLucche, you moved the file.)

@NickLucche (Contributor)

This is a known issue; there is a PR addressing it here: #13898. It won't block your PR.

@NickLucche (Contributor)

I see there's some other problem with building the image, but CI likely just needs another spin.

@LucasWilkinson (Collaborator)

@mickaelseznec apologies for the delay; vllm-project/flash-attention#50 (review) has been merged, so you can now point to vllm_flash_attn.

We will need to populate the sccache on the server to get it through CI; I can help with this once the tag is updated 👍

@LucasWilkinson (Collaborator) left a comment

Thanks for the contribution! Looks clean 😄. I'll approve once it's updated to use vllm_flash_attn; I added a couple of comments.


q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))
Review comment (Collaborator):

nit: could we maybe test per-head scales here too, i.e., also test with non-zero strides?
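
A sketch of what that suggested test variant might look like (hypothetical test code, not part of the PR):

```python
# Hypothetical test variant: per-head descale tensors with real, non-zero
# strides, in contrast to the stride-0 broadcast that .expand() produces above.
import torch

num_seqs, num_kv_heads = 4, 8
q_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)
k_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)
v_descale = torch.rand(num_seqs, num_kv_heads, dtype=torch.float32)
# q_scale.expand(...) on a scalar yields strides of (0, 0); these are (num_kv_heads, 1).
assert q_descale.stride() == (num_kv_heads, 1)
```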

@@ -240,15 +240,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and \
Review comment (Collaborator):

We should keep this check but restrict it to FA2, i.e. check get_flash_attn_version() != 2 (get_flash_attn_version() is in vllm/attention/backends/utils.py).
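
A sketch under those assumptions (the helper name and exact condition are illustrative, not the final diff):

```python
# Illustrative helper, not the final diff: keep rejecting an FP8 KV cache only
# when FlashAttention-2 is in use, since FA3 can handle FP8 directly.
from typing import Optional

from vllm.attention.backends.utils import get_flash_attn_version

def needs_fp8_fallback(kv_cache_dtype: Optional[str]) -> bool:
    # True -> caller should pick a non-FlashAttention backend (e.g. XFORMERS).
    return (kv_cache_dtype is not None
            and kv_cache_dtype.startswith("fp8")
            and get_flash_attn_version() == 2)
```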
