
[WIP] FP8 Query for PA#212

Open
xinyu-intel wants to merge 4 commits into vllm-project:main from xinyu-intel:dev/fp8-query-pa

Conversation

@xinyu-intel
Collaborator

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.

Purpose

depends on #153

Test Plan

Test Result

(Optional) Documentation Update

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)

Copilot AI review requested due to automatic review settings March 21, 2026 03:06
Contributor

Copilot AI left a comment


Pull request overview

This PR extends the XPU FlashAttention2 path to support FP8 query (and typed Q/KV/O dispatch), including plumbing q_scale through the Python interface, C++ flash API, and XE2 Cutlass kernels.

Changes:

  • Add q_scale/q_descale support end-to-end for varlen attention, including XE2 chunk-prefill and paged-decode paths.
  • Refactor XE2 kernel dispatch to include output dtype (CutlassQKOType) and generate typed kernel instantiations via CMake.
  • Expand/adjust FlashAttention tests to cover FP8 query cases.
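The q_scale/q_descale pair being plumbed through these layers is a per-tensor dequantization factor. As background, a minimal pure-Python sketch of how such a pair is commonly derived for symmetric per-tensor e4m3 quantization (illustrative only; the actual kernels in this PR may compute scales differently):

```python
# Hypothetical sketch of per-tensor FP8 (e4m3) quantization, as commonly
# used to produce a q_scale/q_descale pair. Names are illustrative only.
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def compute_fp8_scales(values):
    """Return (scale, descale) for symmetric per-tensor quantization."""
    amax = max(abs(v) for v in values)
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0  # multiply to quantize
    descale = 1.0 / scale                             # multiply to dequantize
    return scale, descale

def fake_quantize(values):
    """Quantize-dequantize round trip (clamping only; no e4m3 rounding)."""
    scale, descale = compute_fp8_scales(values)
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale)) for v in values]
    return [x * descale for x in q]

vals = [0.5, -2.0, 1.25]
scale, descale = compute_fp8_scales(vals)
print(scale)               # amax = 2.0 -> scale = 224.0
print(fake_quantize(vals)) # values survive the round trip (no rounding modeled)
```

The kernel side then multiplies dequantized Q by descale (here, 1/scale) before the softmax, which is why a missing q_scale must be caught at the interface boundary.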

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
vllm_xpu_kernels/flash_attn_interface.py Plumbs q_descale into FA2 op call and enforces out is provided for FP8-Q.
tests/flash_attn/test_flash_attn_varlen_func.py Adds FP8 query parametrization and updates reference/IO handling.
csrc/flash_attn/flash_api.cpp Adds q_scale to the custom op signature and forwards it into XPU attention interfaces.
csrc/xpu/attn/attn_interface.h Extends XPU attention interfaces to accept q_scale.
csrc/xpu/attn/attn_interface.cpp Forwards q_scale through the XPU attention interface layer.
csrc/xpu/attn/xe_2/fmha_xe2.h Extends XE2 chunk-prefill entrypoint signature with q_scale.
csrc/xpu/attn/xe_2/fmha_xe2.cpp Passes q_scale into kernel args and switches to CutlassQKOType dispatch.
csrc/xpu/attn/xe_2/fmha_utils.hpp Introduces CutlassQKOType and aten_to_Cutlass_qko_dtype.
csrc/xpu/attn/xe_2/chunk_prefill.hpp Adds q_scale in args, switches sink/MMA to output dtype, adds typed dispatch helper.
csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp Updates dispatch plumbing to use CutlassQKOType and include q_scale.
csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp Declares extern templates for typed dispatch across allowed dtype combinations.
csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in Updates explicit instantiation to typed dispatch template.
csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake Expands codegen across Q/KV/O dtype combinations.
csrc/xpu/attn/xe_2/paged_decode_xe2.h Extends XE2 paged-decode entrypoint signature with q_scale.
csrc/xpu/attn/xe_2/paged_decode_xe2.cpp Adds FP8-Q scale validation and passes q_scale into kernel args.
csrc/xpu/attn/xe_2/paged_decode.hpp Adds q_scale in args, switches sink/MMA to output dtype, adds typed dispatch helper.
csrc/xpu/attn/xe_2/paged_decode_utils.hpp Updates dispatch plumbing to use CutlassQKOType and include q_scale.
csrc/xpu/attn/xe_2/paged_decode_extern.hpp Declares extern templates for typed paged-decode dispatch across allowed dtype combinations.
csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in Updates explicit instantiation to typed dispatch template.
csrc/xpu/attn/xe_2/paged_decode_configure.cmake Expands codegen across Q/KV/O dtype combinations.
csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp Adds FP8-Q scaling in mainloop and converts scaling casts to output dtype.
csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp Updates sink element type to use output tensor dtype.


inline bool is_single_value_broadcast_tensor(const at::Tensor& t) {
  if (t.scalar_type() != at::ScalarType::Float) {
    return false;
  }

Copilot AI Mar 21, 2026


is_single_value_broadcast_tensor can return true for empty tensors (e.g., numel==0) and the callers then dereference data_ptr() as if a scalar exists. Consider adding a t.numel() >= 1 check (and updating the downstream error messages which currently say “single element” even though expanded/broadcast views are now accepted).

Suggested change
  }
  if (t.numel() < 1) {
    return false;
  }
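For intuition, the predicate being suggested — float32, at least one element, and either a true scalar or a broadcast view in which every stride is zero — can be modeled in pure Python (a hypothetical stand-in for the C++ helper, not the actual implementation):

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Minimal stand-in for at::Tensor, for illustration only."""
    dtype: str
    numel: int
    strides: tuple

def is_single_value_broadcast(t: FakeTensor) -> bool:
    # Mirrors the suggested checks: float32, non-empty, and either a
    # true scalar or an expanded view (all strides zero -> one datum).
    if t.dtype != "float32":
        return False
    if t.numel < 1:
        return False
    return t.numel == 1 or all(s == 0 for s in t.strides)

print(is_single_value_broadcast(FakeTensor("float32", 1, ())))      # True: scalar
print(is_single_value_broadcast(FakeTensor("float32", 8, (0, 0))))  # True: expanded scalar
print(is_single_value_broadcast(FakeTensor("float32", 0, (1,))))    # False: empty
print(is_single_value_broadcast(FakeTensor("float16", 1, ())))      # False: wrong dtype
```

The empty case is the one the review flags: without the numel guard, callers would read data_ptr() of a tensor that holds no data at all.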

Comment on lines +10 to +40
# Allowed dtype combinations must match runtime dispatch constraints. Format:
# Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG
set(dtype_combo_list
"half_t|half_t|half_t|h_h_h"
"half_t|float_e4m3_t|half_t|h_e4_h"
"half_t|float_e5m2_t|half_t|h_e5_h"
"bfloat16_t|bfloat16_t|bfloat16_t|b_b_b"
"bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b"
"bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b"
"float_e4m3_t|float_e4m3_t|half_t|e4_e4_h"
"float_e4m3_t|float_e4m3_t|bfloat16_t|e4_e4_b"
"float_e5m2_t|float_e5m2_t|half_t|e5_e5_h"
"float_e5m2_t|float_e5m2_t|bfloat16_t|e5_e5_b")

foreach(IMPL_POLICY ${policy_list})
# foreach(IMPL_T ${L_TYPES})
foreach(IMPL_KISPAGED ${L_BOOLS})
foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
set(FILE_SUFFIX "${IMPL_POLICY}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
configure_file(${FILENAME_SUFFIX}.cpp.in
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
list(
APPEND
GEN_KERNEL_SRCS
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
)
foreach(dtype_combo ${dtype_combo_list})
string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
list(GET dtype_parts 0 IMPL_Q_T)
list(GET dtype_parts 1 IMPL_KV_T)
list(GET dtype_parts 2 IMPL_O_T)
list(GET dtype_parts 3 DTYPE_TAG)

foreach(IMPL_KISPAGED ${L_BOOLS})
foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")

Copilot AI Mar 21, 2026


This change multiplies kernel instantiations by the 10-entry dtype_combo_list (now: 5 policies × 10 dtype combos × 16 bool combos = 800 generated sources). That can significantly increase build time and CI load. If possible, gate FP8-related instantiations behind a build option, or reduce the dtype list to only combinations that are actually reachable on XPU (or only those enabled by a CMake flag).
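The blow-up described in this comment (and in the paged-decode comment below) is easy to sanity-check. Taking the policy, flag, and dtype counts as stated in the review text — not recomputed from the CMake files — the arithmetic works out as follows:

```python
# Counts taken from the review comments; a quick check of the stated totals.
policies = 5
dtype_combos = 10        # entries in dtype_combo_list
prefill_bool_flags = 4   # KISPAGED, KISCAUSAL, KISLOCAL, KISSINK

chunk_prefill_sources = policies * dtype_combos * (2 ** prefill_bool_flags)
print(chunk_prefill_sources)   # 800 generated chunk-prefill sources

# Paged-decode: 2 qgroups x 5 head sizes x 2 page sizes x 10 dtype combos
# x 8 bool combos (3 flags), per the second review comment.
paged_decode_sources = 2 * 5 * 2 * dtype_combos * (2 ** 3)
print(paged_decode_sources)    # 1600 generated paged-decode sources
```

Gating the six FP8-bearing dtype combos behind a build flag would cut both totals by more than half for builds that never dispatch FP8.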

Comment on lines 65 to +107
@@ -74,26 +88,35 @@ function(paged_decode_configure FILENAME_SUFFIX)
set(IMPL_POLICY
${policy_${IMPL_QGROUP}_${IMPL_HEADSIZE}_${IMPL_PAGESIZE}})

foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
# Construct unique filename suffix: e.g., _q8_h64_fff
set(FILE_SUFFIX
"_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")

# Generate .cpp file from template
configure_file(${FILENAME_SUFFIX}.cpp.in
"${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp")

# Add to output list
list(
APPEND
GEN_KERNEL_SRCS
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp"
)
foreach(dtype_combo ${dtype_combo_list})
string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
list(GET dtype_parts 0 IMPL_Q_T)
list(GET dtype_parts 1 IMPL_KV_T)
list(GET dtype_parts 2 IMPL_O_T)
list(GET dtype_parts 3 DTYPE_TAG)

foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
# Construct unique filename suffix: e.g., _q8_h64_fff
set(FILE_SUFFIX
"_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${DTYPE_TAG}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")

Copilot AI Mar 21, 2026


This adds dtype_combo_list to paged-decode codegen, expanding generated sources roughly 10× (2 qgroups × 5 heads × 2 pages × 10 dtype combos × 8 bool combos = 1600 files). That scale can materially impact compile times and binary size. Consider limiting instantiation to the dtype combos actually enabled/needed (e.g., behind an FP8 build flag) or generating only for architectures where FP8 is supported.

Comment on lines +137 to 139
is_fp8_q ? q_scale.value().data_ptr() : nullptr,
is_fp8_kv ? k_scale.value().data_ptr() : nullptr,
is_fp8_kv ? v_scale.value().data_ptr() : nullptr,

Copilot AI Mar 21, 2026


cutlass_chunk_prefill_impl dereferences q_scale.value(), k_scale.value(), and v_scale.value() when FP8 is detected, but there are no has_value() / dtype / shape checks here. If a caller passes FP8 Q/KV without providing the corresponding scales, this will throw before any useful error (and may bypass the Python-side checks). Add TORCH_CHECK(q_scale.has_value(), ...) for FP8 Q and TORCH_CHECK(k_scale.has_value() && v_scale.has_value(), ...) for FP8 KV (and validate float32 + single-value/broadcast semantics similar to paged_decode_xe2.cpp).

Comment on lines +71 to +75
bool is_fp8_q = q_type == at::ScalarType::Float8_e5m2 ||
                q_type == at::ScalarType::Float8_e4m3fn;
bool is_fp8kv = k_type == at::ScalarType::Float8_e5m2 ||
                k_type == at::ScalarType::Float8_e4m3fn;
if (is_fp8kv == is_fp8_q) {

Copilot AI Mar 21, 2026


The dtype validation here allows unsupported combinations (e.g., FP8 query with non-FP8 KV) because it only checks k.dtype == q.dtype when is_fp8kv == is_fp8_q. This can later crash in the XE2 kernels or hit q_scale.value() paths without proper guards. Tighten the checks to match the allowed runtime dispatch matrix (FP8 Q requires FP8 KV of the same FP8 dtype, and requires q_scale; FP8 KV requires both k_scale and v_scale).
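A tightened check of the kind this comment asks for can be sketched against the allowed dispatch matrix from the dtype_combo_list above. This is a hypothetical pure-Python model (dtype names as strings); the real code would compare at::ScalarType values and use TORCH_CHECK:

```python
# Hypothetical sketch of a validation routine matching the allowed dispatch
# matrix: FP8 Q requires same-dtype FP8 KV plus q_scale; FP8 KV requires
# k_scale and v_scale; output stays fp16/bf16. Illustrative only.
FP8 = {"float8_e4m3fn", "float8_e5m2"}

def validate_dtypes(q, k, o, has_q_scale, has_k_scale, has_v_scale):
    if q in FP8:
        if k != q:
            raise ValueError("FP8 query requires FP8 KV of the same dtype")
        if not has_q_scale:
            raise ValueError("FP8 query requires q_scale")
    if k in FP8 and not (has_k_scale and has_v_scale):
        raise ValueError("FP8 KV requires both k_scale and v_scale")
    if o in FP8:
        raise ValueError("output dtype must be fp16/bf16")

# FP8 Q with non-FP8 KV is rejected up front instead of crashing in the kernel:
try:
    validate_dtypes("float8_e4m3fn", "half", "half", True, True, True)
except ValueError as e:
    print(e)  # FP8 query requires FP8 KV of the same dtype

validate_dtypes("float8_e4m3fn", "float8_e4m3fn", "half", True, True, True)  # ok
```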

Comment on lines +116 to +117
assert out is not None, \
    "output must be provided when q_descale is used"

Copilot AI Mar 21, 2026


Using assert for user-input validation (q_descale requiring out) is risky because asserts can be disabled with Python optimizations, and it raises an AssertionError rather than a stable API error. Prefer an explicit exception (e.g., ValueError) and, ideally, validate q_descale dtype/stride similarly to k_descale/v_descale.

Suggested change
# before
assert out is not None, \
    "output must be provided when q_descale is used"
# after
if out is None:
    raise ValueError(
        "output must be provided when q_descale is used")
if sum(q_descale.stride()) != 0 or \
        q_descale.dtype != torch.float32:
    raise ValueError(
        "q_descale must be a view of a single float32 scalar tensor")

Comment on lines 433 to 450
output = flash_attn_varlen_func(maybe_quantized_query,
maybe_quantized_key_cache,
maybe_quantized_value_cache,
max_query_len,
cu_query_lens,
max_kv_len,
seqused_k=seq_k,
softmax_scale=scale,
causal=False,
block_table=block_tables,
k_descale=k_descale,
v_descale=v_descale,
q_descale=q_descale.expand(scale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(scale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(scale_shape)
if v_descale is not None else None,
window_size=window_size,
s_aux=sink)

Copilot AI Mar 21, 2026


test_decode_with_paged_kv passes q_descale=... but does not pass an out= tensor to flash_attn_varlen_func. With the new interface, FA2 asserts out is not None when q_descale is used, so FP8-query cases will fail immediately. Allocate an output tensor with fp16/bf16 dtype and pass it via out=... (consistent with test_varlen_with_paged_kv).

Comment on lines +133 to +134
"fp8_dtypes": [torch.float8_e4m3fn, None],
"q_dtypes": [torch.float8_e4m3fn, None]

Copilot AI Mar 21, 2026


In MINI_PYTEST_PARAMS for test_varlen_with_paged_kv, the keys fp8_dtypes and q_dtypes don’t match the actual parametrized argument names (fp8_dtype, q_dtype). Under the MINI profiler mode, these overrides won’t apply and the test will run the full Cartesian product unexpectedly. Rename the keys to the real parameter names so pytest_generate_tests can override them.

Suggested change
"fp8_dtypes": [torch.float8_e4m3fn, None],
"q_dtypes": [torch.float8_e4m3fn, None]
"fp8_dtype": [torch.float8_e4m3fn, None],
"q_dtype": [torch.float8_e4m3fn, None]

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
