Pull request overview
This PR extends the XPU FlashAttention2 path to support FP8 query (and typed Q/KV/O dispatch), including plumbing q_scale through the Python interface, C++ flash API, and XE2 Cutlass kernels.
Changes:
- Add `q_scale`/`q_descale` support end-to-end for varlen attention, including XE2 chunk-prefill and paged-decode paths.
- Refactor XE2 kernel dispatch to include output dtype (`CutlassQKOType`) and generate typed kernel instantiations via CMake.
- Expand/adjust FlashAttention tests to cover FP8 query cases.
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| vllm_xpu_kernels/flash_attn_interface.py | Plumbs q_descale into FA2 op call and enforces out is provided for FP8-Q. |
| tests/flash_attn/test_flash_attn_varlen_func.py | Adds FP8 query parametrization and updates reference/IO handling. |
| csrc/flash_attn/flash_api.cpp | Adds q_scale to the custom op signature and forwards it into XPU attention interfaces. |
| csrc/xpu/attn/attn_interface.h | Extends XPU attention interfaces to accept q_scale. |
| csrc/xpu/attn/attn_interface.cpp | Forwards q_scale through the XPU attention interface layer. |
| csrc/xpu/attn/xe_2/fmha_xe2.h | Extends XE2 chunk-prefill entrypoint signature with q_scale. |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Passes q_scale into kernel args and switches to CutlassQKOType dispatch. |
| csrc/xpu/attn/xe_2/fmha_utils.hpp | Introduces CutlassQKOType and aten_to_Cutlass_qko_dtype. |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Adds q_scale in args, switches sink/MMA to output dtype, adds typed dispatch helper. |
| csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp | Updates dispatch plumbing to use CutlassQKOType and include q_scale. |
| csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp | Declares extern templates for typed dispatch across allowed dtype combinations. |
| csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in | Updates explicit instantiation to typed dispatch template. |
| csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake | Expands codegen across Q/KV/O dtype combinations. |
| csrc/xpu/attn/xe_2/paged_decode_xe2.h | Extends XE2 paged-decode entrypoint signature with q_scale. |
| csrc/xpu/attn/xe_2/paged_decode_xe2.cpp | Adds FP8-Q scale validation and passes q_scale into kernel args. |
| csrc/xpu/attn/xe_2/paged_decode.hpp | Adds q_scale in args, switches sink/MMA to output dtype, adds typed dispatch helper. |
| csrc/xpu/attn/xe_2/paged_decode_utils.hpp | Updates dispatch plumbing to use CutlassQKOType and include q_scale. |
| csrc/xpu/attn/xe_2/paged_decode_extern.hpp | Declares extern templates for typed paged-decode dispatch across allowed dtype combinations. |
| csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in | Updates explicit instantiation to typed dispatch template. |
| csrc/xpu/attn/xe_2/paged_decode_configure.cmake | Expands codegen across Q/KV/O dtype combinations. |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp | Adds FP8-Q scaling in mainloop and converts scaling casts to output dtype. |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp | Updates sink element type to use output tensor dtype. |
```cpp
inline bool is_single_value_broadcast_tensor(const at::Tensor& t) {
  if (t.scalar_type() != at::ScalarType::Float) {
    return false;
  }
```
`is_single_value_broadcast_tensor` can return true for empty tensors (e.g., `numel() == 0`), and the callers then dereference `data_ptr()` as if a scalar exists. Consider adding a `t.numel() >= 1` check (and updating the downstream error messages, which currently say "single element" even though expanded/broadcast views are now accepted).
```diff
   }
 }
+if (t.numel() < 1) {
+  return false;
+}
```
```diff
+# Allowed dtype combinations must match runtime dispatch constraints. Format:
+# Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG
+set(dtype_combo_list
+    "half_t|half_t|half_t|h_h_h"
+    "half_t|float_e4m3_t|half_t|h_e4_h"
+    "half_t|float_e5m2_t|half_t|h_e5_h"
+    "bfloat16_t|bfloat16_t|bfloat16_t|b_b_b"
+    "bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b"
+    "bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b"
+    "float_e4m3_t|float_e4m3_t|half_t|e4_e4_h"
+    "float_e4m3_t|float_e4m3_t|bfloat16_t|e4_e4_b"
+    "float_e5m2_t|float_e5m2_t|half_t|e5_e5_h"
+    "float_e5m2_t|float_e5m2_t|bfloat16_t|e5_e5_b")
+
 foreach(IMPL_POLICY ${policy_list})
-  # foreach(IMPL_T ${L_TYPES})
-  foreach(IMPL_KISPAGED ${L_BOOLS})
-    foreach(IMPL_KISCAUSAL ${L_BOOLS})
-      foreach(IMPL_KISLOCAL ${L_BOOLS})
-        foreach(IMPL_KISSINK ${L_BOOLS})
-          set(FILE_SUFFIX "${IMPL_POLICY}_")
-          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
-          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
-          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
-          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
-          configure_file(${FILENAME_SUFFIX}.cpp.in
-                         "${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
-          list(
-            APPEND
-            GEN_KERNEL_SRCS
-            "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
-          )
+  foreach(dtype_combo ${dtype_combo_list})
+    string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
+    list(GET dtype_parts 0 IMPL_Q_T)
+    list(GET dtype_parts 1 IMPL_KV_T)
+    list(GET dtype_parts 2 IMPL_O_T)
+    list(GET dtype_parts 3 DTYPE_TAG)
+
+    foreach(IMPL_KISPAGED ${L_BOOLS})
+      foreach(IMPL_KISCAUSAL ${L_BOOLS})
+        foreach(IMPL_KISLOCAL ${L_BOOLS})
+          foreach(IMPL_KISSINK ${L_BOOLS})
+            set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_")
+            set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
+            set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
+            set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
+            set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
```
This change multiplies kernel instantiations by the 10-entry dtype_combo_list (now: 5 policies × 10 dtype combos × 16 bool combos = 800 generated sources). That can significantly increase build time and CI load. If possible, gate FP8-related instantiations behind a build option, or reduce the dtype list to only combinations that are actually reachable on XPU (or only those enabled by a CMake flag).
```diff
@@ -74,26 +88,35 @@ function(paged_decode_configure FILENAME_SUFFIX)
   set(IMPL_POLICY
       ${policy_${IMPL_QGROUP}_${IMPL_HEADSIZE}_${IMPL_PAGESIZE}})
 
-  foreach(IMPL_KISCAUSAL ${L_BOOLS})
-    foreach(IMPL_KISLOCAL ${L_BOOLS})
-      foreach(IMPL_KISSINK ${L_BOOLS})
-        # Construct unique filename suffix: e.g., _q8_h64_fff
-        set(FILE_SUFFIX
-            "_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_")
-        set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
-        set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
-        set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
-
-        # Generate .cpp file from template
-        configure_file(${FILENAME_SUFFIX}.cpp.in
-                       "${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp")
-
-        # Add to output list
-        list(
-          APPEND
-          GEN_KERNEL_SRCS
-          "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp"
-        )
+  foreach(dtype_combo ${dtype_combo_list})
+    string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
+    list(GET dtype_parts 0 IMPL_Q_T)
+    list(GET dtype_parts 1 IMPL_KV_T)
+    list(GET dtype_parts 2 IMPL_O_T)
+    list(GET dtype_parts 3 DTYPE_TAG)
+
+    foreach(IMPL_KISCAUSAL ${L_BOOLS})
+      foreach(IMPL_KISLOCAL ${L_BOOLS})
+        foreach(IMPL_KISSINK ${L_BOOLS})
+          # Construct unique filename suffix: e.g., _q8_h64_fff
+          set(FILE_SUFFIX
+              "_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_")
+          set(FILE_SUFFIX "${FILE_SUFFIX}${DTYPE_TAG}_")
+          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
+          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
+          set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
```
This adds dtype_combo_list to paged-decode codegen, expanding generated sources roughly 10× (2 qgroups × 5 heads × 2 pages × 10 dtype combos × 8 bool combos = 1600 files). That scale can materially impact compile times and binary size. Consider limiting instantiation to the dtype combos actually enabled/needed (e.g., behind an FP8 build flag) or generating only for architectures where FP8 is supported.
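The counts quoted in these two comments are easy to verify. Assuming the stated factors (5 chunk-prefill policies with 4 boolean flags, and 2 qgroups × 5 head sizes × 2 page sizes with 3 boolean flags for paged-decode), the blow-up is:

```python
dtype_combos = 10  # entries in dtype_combo_list
bool_states = 2    # each K_IS_* flag is generated for true and false

# Chunk-prefill: 5 policies, 4 boolean flags (paged/causal/local/sink)
chunk_prefill = 5 * dtype_combos * bool_states**4
# Paged-decode: 2 qgroups x 5 head sizes x 2 page sizes, 3 boolean flags
paged_decode = 2 * 5 * 2 * dtype_combos * bool_states**3

print(chunk_prefill, paged_decode)  # 800 1600
```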
```cpp
is_fp8_q ? q_scale.value().data_ptr() : nullptr,
is_fp8_kv ? k_scale.value().data_ptr() : nullptr,
is_fp8_kv ? v_scale.value().data_ptr() : nullptr,
```
cutlass_chunk_prefill_impl dereferences q_scale.value(), k_scale.value(), and v_scale.value() when FP8 is detected, but there are no has_value() / dtype / shape checks here. If a caller passes FP8 Q/KV without providing the corresponding scales, this will throw before any useful error (and may bypass the Python-side checks). Add TORCH_CHECK(q_scale.has_value(), ...) for FP8 Q and TORCH_CHECK(k_scale.has_value() && v_scale.has_value(), ...) for FP8 KV (and validate float32 + single-value/broadcast semantics similar to paged_decode_xe2.cpp).
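A Python-side sketch of the guard logic this comment asks for (the function and parameter names below are illustrative; the real fix would be `TORCH_CHECK`s inside `cutlass_chunk_prefill_impl`):

```python
def check_fp8_scales(is_fp8_q, is_fp8_kv, q_scale, k_scale, v_scale):
    """Fail fast with a readable error instead of crashing later on
    optional::value() inside the kernel-launch path."""
    if is_fp8_q and q_scale is None:
        raise ValueError("q_scale must be provided when query is FP8")
    if is_fp8_kv and (k_scale is None or v_scale is None):
        raise ValueError(
            "k_scale and v_scale must be provided when KV cache is FP8")


# FP8 query without a scale is rejected up front with a clear message:
try:
    check_fp8_scales(True, True, None, 1.0, 1.0)
except ValueError as e:
    print(e)  # q_scale must be provided when query is FP8
```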
```cpp
bool is_fp8_q = q_type == at::ScalarType::Float8_e5m2 ||
                q_type == at::ScalarType::Float8_e4m3fn;
bool is_fp8kv = k_type == at::ScalarType::Float8_e5m2 ||
                k_type == at::ScalarType::Float8_e4m3fn;
if (is_fp8kv == is_fp8_q) {
```
The dtype validation here allows unsupported combinations (e.g., FP8 query with non-FP8 KV) because it only checks k.dtype == q.dtype when is_fp8kv == is_fp8_q. This can later crash in the XE2 kernels or hit q_scale.value() paths without proper guards. Tighten the checks to match the allowed runtime dispatch matrix (FP8 Q requires FP8 KV of the same FP8 dtype, and requires q_scale; FP8 KV requires both k_scale and v_scale).
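One way to tighten this is an explicit whitelist mirroring the (Q, KV) pairs in the CMake `dtype_combo_list`: FP8 Q requires FP8 KV of the same FP8 dtype, while fp16/bf16 Q may pair with a matching dtype or either FP8 KV dtype. The sketch below uses string dtype names for illustration, not the actual `at::ScalarType` values:

```python
# (q_dtype, kv_dtype) pairs mirrored from dtype_combo_list
ALLOWED_QKV = {
    ("half", "half"), ("half", "e4m3"), ("half", "e5m2"),
    ("bf16", "bf16"), ("bf16", "e4m3"), ("bf16", "e5m2"),
    ("e4m3", "e4m3"), ("e5m2", "e5m2"),
}


def validate_qkv_dtypes(q_dtype, kv_dtype):
    # Reject anything outside the dispatch matrix before reaching
    # the kernels, instead of only comparing dtypes when both (or
    # neither) are FP8.
    if (q_dtype, kv_dtype) not in ALLOWED_QKV:
        raise ValueError(
            f"unsupported dtype combination: q={q_dtype}, kv={kv_dtype}")


validate_qkv_dtypes("e4m3", "e4m3")  # ok: FP8 Q with matching FP8 KV
# FP8 query with non-FP8 KV is rejected, unlike the current check:
try:
    validate_qkv_dtypes("e4m3", "half")
except ValueError as e:
    print(e)
```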
```python
assert out is not None, \
    "output must be provided when q_descale is used"
```
Using assert for user-input validation (q_descale requiring out) is risky because asserts can be disabled with Python optimizations, and it raises an AssertionError rather than a stable API error. Prefer an explicit exception (e.g., ValueError) and, ideally, validate q_descale dtype/stride similarly to k_descale/v_descale.
```diff
-assert out is not None, \
-    "output must be provided when q_descale is used"
+if out is None:
+    raise ValueError(
+        "output must be provided when q_descale is used"
+    )
+if sum(q_descale.stride()) != 0 or \
+        q_descale.dtype != torch.float32:
+    raise ValueError(
+        "q_descale must be view of single float32 scalar tensor"
+    )
```
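The rationale behind this suggestion can be demonstrated directly: CPython drops `assert` statements when run with `-O`, so the guard silently vanishes under optimized bytecode, whereas a raised exception always fires:

```python
import subprocess
import sys

code = 'assert False, "output must be provided"'

# Without optimization the assert fires and the process exits non-zero...
normal = subprocess.run([sys.executable, "-c", code])
# ...but under -O the same assert is compiled away entirely.
optimized = subprocess.run([sys.executable, "-O", "-c", code])

print(normal.returncode != 0, optimized.returncode == 0)  # True True
```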
```diff
 output = flash_attn_varlen_func(maybe_quantized_query,
                                 maybe_quantized_key_cache,
                                 maybe_quantized_value_cache,
                                 max_query_len,
                                 cu_query_lens,
                                 max_kv_len,
                                 seqused_k=seq_k,
                                 softmax_scale=scale,
                                 causal=False,
                                 block_table=block_tables,
-                                k_descale=k_descale,
-                                v_descale=v_descale,
+                                q_descale=q_descale.expand(scale_shape)
+                                if q_descale is not None else None,
+                                k_descale=k_descale.expand(scale_shape)
+                                if k_descale is not None else None,
+                                v_descale=v_descale.expand(scale_shape)
+                                if v_descale is not None else None,
                                 window_size=window_size,
                                 s_aux=sink)
```
test_decode_with_paged_kv passes q_descale=... but does not pass an out= tensor to flash_attn_varlen_func. With the new interface, FA2 asserts out is not None when q_descale is used, so FP8-query cases will fail immediately. Allocate an output tensor with fp16/bf16 dtype and pass it via out=... (consistent with test_varlen_with_paged_kv).
| "fp8_dtypes": [torch.float8_e4m3fn, None], | ||
| "q_dtypes": [torch.float8_e4m3fn, None] |
In MINI_PYTEST_PARAMS for test_varlen_with_paged_kv, the keys fp8_dtypes and q_dtypes don’t match the actual parametrized argument names (fp8_dtype, q_dtype). Under the MINI profiler mode, these overrides won’t apply and the test will run the full Cartesian product unexpectedly. Rename the keys to the real parameter names so pytest_generate_tests can override them.
| "fp8_dtypes": [torch.float8_e4m3fn, None], | |
| "q_dtypes": [torch.float8_e4m3fn, None] | |
| "fp8_dtype": [torch.float8_e4m3fn, None], | |
| "q_dtype": [torch.float8_e4m3fn, None] |
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Purpose
depends on #153
Test Plan
Test Result
(Optional) Documentation Update