Skip to content

Commit cbd474c

Browse files
committed
add fp8 query test for pa
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent 9ce2174 commit cbd474c

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tests/flash_attn/test_flash_attn_varlen_func.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -365,8 +365,9 @@ def test_decode_with_paged_kv(
365365
if is_sink and window_size != (-1, -1):
366366
pytest.skip("sink not supported with sliding window")
367367
is_fp8_query = q_dtype is not None
368-
if is_fp8_query:
369-
pytest.skip("skip cases with fp8 query")
368+
is_fp8kv = fp8_dtype is not None
369+
if is_fp8_query and q_dtype != fp8_dtype:
370+
pytest.skip("skip cases with fp8 query and non-fp8 kv")
370371
torch.manual_seed(42)
371372
num_seqs = len(seq_lens)
372373
query_lens = [x[0] for x in seq_lens]
@@ -419,9 +420,11 @@ def test_decode_with_paged_kv(
419420
q_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841
420421
k_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841
421422
v_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841
422-
is_fp8kv = False
423-
if fp8_dtype is not None:
424-
is_fp8kv = True
423+
scale_shape = (num_seqs, num_kv_heads)
424+
if is_fp8_query:
425+
q_descale = (torch.abs(query).max() / 200).to(torch.float32)
426+
maybe_quantized_query = (query / q_descale).to(q_dtype)
427+
if is_fp8kv:
425428
k_descale = (torch.abs(key_cache).max() / 200).to(torch.float32)
426429
v_descale = (torch.abs(value_cache).max() / 200).to(torch.float32)
427430
maybe_quantized_key_cache = (key_cache / k_descale).to(fp8_dtype)
@@ -437,8 +440,12 @@ def test_decode_with_paged_kv(
437440
softmax_scale=scale,
438441
causal=False,
439442
block_table=block_tables,
440-
k_descale=k_descale,
441-
v_descale=v_descale,
443+
q_descale=q_descale.expand(scale_shape)
444+
if q_descale is not None else None,
445+
k_descale=k_descale.expand(scale_shape)
446+
if k_descale is not None else None,
447+
v_descale=v_descale.expand(scale_shape)
448+
if v_descale is not None else None,
442449
window_size=window_size,
443450
s_aux=sink)
444451

@@ -452,6 +459,7 @@ def test_decode_with_paged_kv(
452459
casual=False,
453460
is_paged=True,
454461
sink=sink,
462+
q_descale=q_descale,
455463
k_descale=k_descale,
456464
v_descale=v_descale,
457465
window_size_left=window_size[0],

0 commit comments

Comments (0)