[Kernel] GGUF MoE kernel #14613
Conversation
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
Overall LGTM! Just some nits about ops registration, PTAL!
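For context on the ops-registration nit, here is a minimal sketch of the general `torch.library` pattern for exposing a custom kernel op so `torch.compile` can trace it. The op name `_gguf_demo::moe_matmul`, its signature, and the dense-matmul body are illustrative assumptions, not the registration code actually used in this PR.

```python
# Hedged sketch: generic custom-op registration, not this PR's implementation.
import torch

@torch.library.custom_op("_gguf_demo::moe_matmul", mutates_args=())
def moe_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # A real implementation would dispatch to the CUDA kernel; a dense matmul
    # stands in here so the sketch runs on any machine.
    return x @ w

@moe_matmul.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape-only "fake" implementation used during tracing/compilation.
    return x.new_empty((x.shape[0], w.shape[1]))
```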
template <typename scalar_t, int qk, int qr, int qi, bool need_sum,
          typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles,
          int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void moe_q(
Is this file adapted/copied from somewhere? If so, we need to add the source of it for easier maintenance.
It's just adapted from the mmq kernel that's already in the repo; I'm not sure if I should mention that.
I think it's still fine to mention it since there's no such kernel in llama.cpp, so that other developers interested in this kernel won't be confused. :)
Sure thing, I added paths to both files I took inspiration from.
else:
    for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
Can you add a warning about the performance degradation of this fallback when the user is using an i-matrix quant?
Good idea, added a warning.
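For readers following along, here is a hedged sketch of what such a per-token fallback with a warning can look like. The function name, tensor layouts, and the `dequantize`/`expert_weights` helpers are placeholders for illustration, not the code merged in this PR.

```python
# Hedged sketch of a slow MoE fallback path with the requested warning.
import logging
import torch

logger = logging.getLogger(__name__)

def moe_fallback(x: torch.Tensor,             # [num_tokens, hidden]
                 expert_weights: list,         # one quantized weight per expert
                 topk_weights: torch.Tensor,   # [num_tokens, top_k]
                 topk_ids: torch.Tensor,       # [num_tokens, top_k]
                 dequantize) -> torch.Tensor:
    """Slow path used when no fused MoE kernel exists for the quant type."""
    logger.warning(
        "GGUF MoE: no fused kernel for this quantization type "
        "(e.g. i-matrix quants); falling back to a per-token loop, "
        "which is significantly slower.")
    out = torch.zeros_like(x)
    for tok, (weights, ids) in enumerate(zip(topk_weights, topk_ids)):
        for w, idx in zip(weights.tolist(), ids.tolist()):
            # Dequantize the selected expert and apply it to this token.
            dense_w = dequantize(expert_weights[idx])
            out[tok] += w * (x[tok] @ dense_w)
    return out
```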
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Amazing achievement! We should explore evals and benchmarks to detail the compression tradeoffs for users.
Overall, this speeds up DeepSeek GGUF and enables graph caching, jumping from 10 to 50 tok/s on 8xH100 for Q4_K quants.
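As a usage note, here is a minimal sketch of exercising the GGUF path through vLLM's offline LLM API; the local model path and tokenizer repo below are placeholders, not the exact setup benchmarked above.

```python
# Hedged sketch: loading a GGUF MoE checkpoint with the offline API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/models/deepseek-moe-Q4_K_M.gguf",        # hypothetical local GGUF file
    tokenizer="deepseek-ai/DeepSeek-V2-Lite-Chat",   # tokenizer from the base repo
    tensor_parallel_size=8,
)

outputs = llm.generate(
    ["Explain mixture-of-experts in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```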