Add mixed precision support for LoRA expand & shrink kernels #230
chaojun-zhang wants to merge 1 commit into vllm-project:main
Conversation
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
Pull request overview
Adds mixed-precision support to the XPU LoRA BGMV kernels so they can accept float32 inputs with fp16/bf16 LoRA weights, and introduces a new test variant to exercise this path.
Changes:
- Extend the `bgmv_shrink`/`bgmv_expand_slice` kernels to dispatch input and weight dtypes independently (allowing float32 inputs with fp16/bf16 weights).
- Update test data generation utilities to allow a separate `input_dtype` vs `weight_dtype`.
- Add a new `test_kernels_mixed_precision` test matrix and mini-profiler params.
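The mixed-precision idea behind the changes above can be pictured as a reference BGMV shrink computation templated independently on the input and weight element types. This is a minimal host-side sketch with hypothetical names (the real kernels are SYCL/XPU code), using `float`/`double` as stand-ins for the fp32/fp16 pairing:

```cpp
#include <cstdint>
#include <vector>

// Reference bgmv_shrink: y[r] += scale * sum_k x[k] * W[r][k].
// input_t and weight_t are templated independently, so e.g. float32
// inputs can pair with fp16/bf16 weights; accumulation stays in float.
// Hypothetical sketch, not the actual kernel signature.
template <typename input_t, typename weight_t>
void bgmv_shrink_ref(const input_t* x, const weight_t* w, float* y,
                     uint32_t hidden, uint32_t rank, float scale) {
  for (uint32_t r = 0; r < rank; ++r) {
    float acc = 0.f;
    for (uint32_t k = 0; k < hidden; ++k) {
      acc += static_cast<float>(x[k]) * static_cast<float>(w[r * hidden + k]);
    }
    y[r] += scale * acc;
  }
}
```

The point of the independent template parameters is that the cast to `float` happens per element inside the kernel, so no full-precision copy of the weights is materialized.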
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/test_lora_ops.py | Adds mixed-precision test coverage and threads `input_dtype` through kernel checks. |
| tests/lora/utils.py | Allows generating inputs with a different dtype than the LoRA weights. |
| csrc/xpu/lora/lora_shrink.cpp | Enables mixed dtype dispatch for shrink (independent input/weight types) and updates validation. |
| csrc/xpu/lora/lora_expand.cpp | Enables mixed dtype dispatch for expand-slice and updates validation. |
| """ | ||
| Tests LoRA kernels with mixed precision: | ||
| input=float32, weight=float16/bfloat16. | ||
| """ |
This PR’s description template is still unfilled (missing Purpose and Test Result details). Please update the PR description so reviewers can understand intent and validate changes more quickly.
```cpp
// Use the minimum vec_size so both types can be vectorized
uint32_t vec_size = std::min(input_vec_size, weight_vec_size);
```
This file now uses std::min to compute vec_size, but it does not include <algorithm>. Please add the direct header include to avoid relying on transitive includes (which can break builds under different toolchains).
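To illustrate why the minimum is taken, here is a small standalone sketch with hypothetical lane widths, assuming 16-byte vector loads (the constant and helper name are illustrative, not from the PR), with the `<algorithm>` include made direct as the comment requests:

```cpp
#include <algorithm>  // included directly for std::min
#include <cstdint>

// With 16-byte vector loads, a 4-byte float input vectorizes 4 lanes
// while a 2-byte fp16 weight vectorizes 8; a mixed-precision kernel
// must use the smaller width so both streams stay vectorizable.
inline uint32_t mixed_vec_size(uint32_t input_elem_bytes,
                               uint32_t weight_elem_bytes) {
  const uint32_t kVecBytes = 16;  // hypothetical register width
  uint32_t input_vec_size = kVecBytes / input_elem_bytes;
  uint32_t weight_vec_size = kVecBytes / weight_elem_bytes;
  return std::min(input_vec_size, weight_vec_size);
}
```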
```cpp
using weight_t = std::remove_const_t<
    std::remove_pointer_t<decltype(weight_ptr)>>;
VLLM_DISPATCH_FLOATING_TYPES(
```
The new dispatch logic uses std::remove_const_t / std::remove_pointer_t but this file does not include <type_traits>. Add the explicit include to avoid build fragility from transitive includes.
```cpp
// Use the minimum vec_size so both types can be vectorized
uint32_t vec_size = std::min(input_vec_size, weight_vec_size);
```
This file now uses std::min to compute vec_size, but it does not include <algorithm>. Please add the direct header include to avoid relying on transitive includes (which can break builds under different toolchains).
```cpp
auto dispatch_output = [&](auto* weight_ptr) {
  using weight_t = std::remove_const_t<
      std::remove_pointer_t<decltype(weight_ptr)>>;
  switch (outputs.scalar_type()) {
```
The new dispatch logic uses std::remove_const_t / std::remove_pointer_t but this file does not include <type_traits>. Add the explicit include to avoid build fragility from transitive includes.
| """ | ||
| Tests LoRA kernels with mixed precision: | ||
| input=float32, weight=float16/bfloat16. | ||
| """ |
The new mixed-precision test sets `input_dtype=torch.float32`, but the Torch reference path (tests/lora/torch_ops.py) casts inputs to `output_tensor.dtype` (fp16/bf16) before computing. As a result, this test may not actually validate the float32-input execution path: it could simply compare against an fp16/bf16 reference. Consider updating the reference calculation (or special-casing this test) so the reference keeps float32 inputs and only casts the final result to the output dtype.
Essential Elements of an Effective PR Description Checklist
- supported_models.md and examples for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.
Purpose
Add mixed precision support for LoRA expand & shrink kernels (used in vllm) with float32 inputs and float16 weights.
Test Plan
pytest -s -v tests/test_lora_ops.py::test_kernels_mixed_precision
vllm side: pytest -s -v tests/lora/test_layers.py
Test Result
(Optional) Documentation Update
BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing