@@ -365,8 +365,9 @@ def test_decode_with_paged_kv(
365365 if is_sink and window_size != (- 1 , - 1 ):
366366 pytest .skip ("sink not supported with sliding window" )
367367 is_fp8_query = q_dtype is not None
368- if is_fp8_query :
369- pytest .skip ("skip cases with fp8 query" )
368+ is_fp8kv = fp8_dtype is not None
369+ if is_fp8_query and q_dtype != fp8_dtype :
370+ pytest .skip ("skip cases with fp8 query and non-fp8 kv" )
370371 torch .manual_seed (42 )
371372 num_seqs = len (seq_lens )
372373 query_lens = [x [0 ] for x in seq_lens ]
@@ -419,9 +420,11 @@ def test_decode_with_paged_kv(
419420 q_descale = torch .ones (scale_shape , dtype = torch .float32 ) #noqa: F841
420421 k_descale = torch .ones (scale_shape , dtype = torch .float32 ) #noqa: F841
421422 v_descale = torch .ones (scale_shape , dtype = torch .float32 ) #noqa: F841
422- is_fp8kv = False
423- if fp8_dtype is not None :
424- is_fp8kv = True
423+ scale_shape = (num_seqs , num_kv_heads )
424+ if is_fp8_query :
425+ q_descale = (torch .abs (query ).max () / 200 ).to (torch .float32 )
426+ maybe_quantized_query = (query / q_descale ).to (q_dtype )
427+ if is_fp8kv :
425428 k_descale = (torch .abs (key_cache ).max () / 200 ).to (torch .float32 )
426429 v_descale = (torch .abs (value_cache ).max () / 200 ).to (torch .float32 )
427430 maybe_quantized_key_cache = (key_cache / k_descale ).to (fp8_dtype )
@@ -437,8 +440,12 @@ def test_decode_with_paged_kv(
437440 softmax_scale = scale ,
438441 causal = False ,
439442 block_table = block_tables ,
440- k_descale = k_descale ,
441- v_descale = v_descale ,
443+ q_descale = q_descale .expand (scale_shape )
444+ if q_descale is not None else None ,
445+ k_descale = k_descale .expand (scale_shape )
446+ if k_descale is not None else None ,
447+ v_descale = v_descale .expand (scale_shape )
448+ if v_descale is not None else None ,
442449 window_size = window_size ,
443450 s_aux = sink )
444451
@@ -452,6 +459,7 @@ def test_decode_with_paged_kv(
452459 casual = False ,
453460 is_paged = True ,
454461 sink = sink ,
462+ q_descale = q_descale ,
455463 k_descale = k_descale ,
456464 v_descale = v_descale ,
457465 window_size_left = window_size [0 ],
0 commit comments