Open
Pull request overview
Adds support for a new fused activation (swiglustep_and_mul) across the XPU extension stack (C++/SYCL kernel → Torch binding → Python dispatch), plus accompanying unit test and benchmark.
Changes:
- Add `swiglustep` activation option to the fused MoE Python interface.
- Register and bind a new `torch.ops._C.swiglustep_and_mul` XPU operator and implement its SYCL kernel.
- Add unit test coverage and a benchmark script for the new op.
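For readers unfamiliar with fused "activation and mul" ops, the following is a minimal pure-Python sketch of what such an op might compute on a single row of a `[..., 2 * d]` input. The formula is an assumption, not something this page confirms: SiLU on the first half capped at `limit`, the second half clamped to `[-limit, limit]`, then an elementwise product. Only the `[..., 2 * d]` layout and the `limit` argument appear in the PR itself; `silu` and `swiglustep_and_mul_ref` are hypothetical names used for illustration.

```python
import math


def silu(v: float) -> float:
    """SiLU (swish): v * sigmoid(v), written to avoid overflow for large |v|."""
    if v >= 0.0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def swiglustep_and_mul_ref(row: list[float], limit: float = 7.0) -> list[float]:
    """Hypothetical reference for one row of length 2 * d.

    ASSUMPTION (not confirmed by the PR page):
      gate = min(silu(row[:d]), limit)
      up   = clamp(row[d:], -limit, limit)
      out  = gate * up
    """
    if len(row) % 2 != 0:
        raise ValueError("last dimension must be even (2 * d)")
    d = len(row) // 2
    out = []
    for g, u in zip(row[:d], row[d:]):
        g_act = min(silu(g), limit)             # cap the activated gate from above
        u_clamped = max(-limit, min(u, limit))  # clamp the linear half on both sides
        out.append(g_act * u_clamped)
    return out
```

A sketch like this can serve as a slow reference when checking a fused kernel's output on small inputs; the actual semantics should be taken from the kernel source in `csrc/activation.cpp`.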
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| vllm_xpu_kernels/fused_moe_interface.py | Routes activation="swiglustep" to the new fused op. |
| csrc/activation.cpp | Implements swiglustep_and_mul device function + kernel + launcher. |
| csrc/torch_bindings.cpp | Registers the new op schema and XPU implementation. |
| csrc/ops.h | Declares the new C++ op entrypoint. |
| tests/register_ops.py | Adds a Python test wrapper for the new op. |
| tests/ops/swiglustep_and_mul_op.py | Adds a CustomOp test harness + native reference implementation. |
| tests/test_swiglustep_and_mul.py | Adds pytest coverage + opcheck for the new op. |
| benchmark/benchmark_swiglustep_and_mul.py | Adds performance benchmarking for the op vs native/compile. |
Comment on lines +12 to +15

    XPU_DEVICES = [
        f"xpu:{i}" for i in range(1 if torch.xpu.device_count() == 1 else 2)
    ]
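One property of the expression in this snippet is that it yields two device strings for any device count other than 1, including 0. Whether that matters depends on how the tests are skipped on machines without XPUs, which this page does not show. A possible alternative, sketched here with the device count passed in as a plain integer so the example runs without PyTorch, caps the list at the number of devices actually present. `pick_test_devices` and `max_devices` are hypothetical names.

```python
def pick_test_devices(device_count: int, max_devices: int = 2) -> list[str]:
    """Return at most `max_devices` XPU device strings, never more than exist.

    In the test file the count would come from torch.xpu.device_count();
    it is a parameter here so the sketch has no PyTorch dependency.
    """
    return [f"xpu:{i}" for i in range(min(device_count, max_devices))]
```

With this form, `pick_test_devices(0)` is an empty list, so a parametrized test would simply collect no device cases on a machine without XPUs.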
Comment on lines +54 to +57

    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    layer = SwigluStepAndMul()
Comment on lines +76 to +77

    d = x.shape[-1] // 2
    output_shape = (x.shape[:-1] + (d, ))
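The shape computation in this snippet halves the last dimension with integer division, so an odd last dimension would be silently truncated. As a small illustration, the same rule can be written over plain tuples with an explicit check. `and_mul_output_shape` is a hypothetical helper, and the validation is an assumption about desirable behavior rather than something the PR is known to include.

```python
def and_mul_output_shape(input_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Map an input shape [..., 2 * d] to the output shape [..., d]."""
    if not input_shape:
        raise ValueError("input must have at least one dimension")
    last = input_shape[-1]
    if last % 2 != 0:
        # `last // 2` would drop one element without this check.
        raise ValueError(f"last dimension must be even, got {last}")
    return input_shape[:-1] + (last // 2,)
```

For example, an input of shape `(8, 256)` maps to an output of shape `(8, 128)`.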
Comment on lines +298 to +299

    elif activation == "swiglustep":
        torch.ops._C.swiglustep_and_mul(act_output, gemm1_output, 7.0)
Comment on lines +11 to +13

    from tests.ops.swiglustep_and_mul_op import SwigluStepAndMul
Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Force-pushed from 21a7552 to bef820a
xinyu-intel reviewed Mar 19, 2026

csrc/activation.cpp (Outdated)

        torch::Tensor& input,  // [..., 2 * d]
        double limit) {
      LAUNCH_SWIGLUSTEP_AND_MUL(vllm::swiglustep_and_mul, limit);
    }

No newline at end of file
Signed-off-by: Qiao, Zhefeng <zhefeng.qiao@intel.com>
Purpose
Add support for the fused `swiglustep_and_mul` activation op on XPU.
Test Plan
python -m pytest tests/test_swiglustep_and_mul.py -v
Test Result
Pass