Skip to content

Commit c3df9d2

Browse files
committed
support fp8 query
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent 80bd4c7 commit c3df9d2

14 files changed

+172
-62
lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
4949
int max_seqlen_q,
5050
int max_seqlen_k,
5151
float p_dropout,
52+
std::optional<const at::Tensor>& q_scale,
5253
std::optional<const at::Tensor>& k_scale,
5354
std::optional<const at::Tensor>& v_scale,
5455
float softmax_scale,
@@ -63,21 +64,17 @@ std::vector<at::Tensor> mha_varlen_fwd(
6364
std::optional<int> num_splits) {
6465
auto q_type = q.scalar_type();
6566
auto k_type = k.scalar_type();
66-
TORCH_CHECK(
67-
q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
68-
"VLLM Kernel XPU only supports fp16 and bf16 type");
67+
auto v_type = v.scalar_type();
6968

7069
TORCH_CHECK(
7170
v.scalar_type() == k_type, "key and value must have the same dtype");
72-
bool is_fp8kv = false;
73-
if (k_type == at::ScalarType::Float8_e5m2 ||
74-
k_type == at::ScalarType::Float8_e4m3fn) {
75-
is_fp8kv = true;
76-
} else {
71+
bool is_fp8_q = q_type == at::ScalarType::Float8_e5m2 ||
72+
q_type == at::ScalarType::Float8_e4m3fn;
73+
bool is_fp8kv = k_type == at::ScalarType::Float8_e5m2 ||
74+
k_type == at::ScalarType::Float8_e4m3fn;
75+
if (is_fp8kv == is_fp8_q) {
7776
TORCH_CHECK(
7877
k.scalar_type() == q_type, "query and key must have the same dtype");
79-
TORCH_CHECK(
80-
v.scalar_type() == q_type, "query and value must have the same dtype");
8178
}
8279

8380
CHECK_DEVICE(q);
@@ -128,6 +125,10 @@ std::vector<at::Tensor> mha_varlen_fwd(
128125
} else {
129126
out = torch::empty_like(q);
130127
}
128+
TORCH_CHECK(
129+
out.scalar_type() == at::ScalarType::Half ||
130+
out.scalar_type() == at::ScalarType::BFloat16,
131+
"VLLM Kernel XPU only supports fp16 and bf16 type");
131132

132133
bool is_varlen = true;
133134
bool is_local = (window_size_left != -1) | (window_size_right != -1);
@@ -147,6 +148,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
147148
seqlens_k,
148149
max_seqlen_q,
149150
max_seqlen_k,
151+
q_scale,
150152
k_scale,
151153
v_scale,
152154
softmax_scale,
@@ -227,8 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
227229
"cu_seqlens_q, "
228230
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? "
229231
"block_table, Tensor? alibi_slopes, "
230-
"int max_seqlen_q, int max_seqlen_k, float p_dropout, Tensor? k_scale, "
231-
"Tensor? v_scale, "
232+
"int max_seqlen_q, int max_seqlen_k, float p_dropout, Tensor? q_scale, "
233+
"Tensor? k_scale, Tensor? v_scale, "
232234
"float softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
233235
"bool is_causal, int window_size_left, int window_size_right, float "
234236
"softcap, bool return_softmax, "

csrc/xpu/attn/attn_interface.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ void cutlass_chunk_prefill_interface(
1717
const at::Tensor& cu_seqlens_k,
1818
int max_seqlen_q,
1919
int max_seqlen_k,
20+
std::optional<const at::Tensor>& q_scale,
2021
std::optional<const at::Tensor>& k_scale,
2122
std::optional<const at::Tensor>& v_scale,
2223
double sm_scale,
@@ -42,6 +43,7 @@ void cutlass_chunk_prefill_interface(
4243
cu_seqlens_k,
4344
max_seqlen_q,
4445
max_seqlen_k,
46+
q_scale,
4547
k_scale,
4648
v_scale,
4749
sm_scale,

csrc/xpu/attn/attn_interface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ void cutlass_chunk_prefill_interface(
1010
const at::Tensor& cu_seqlens_k,
1111
int max_seqlen_q,
1212
int max_seqlen_k,
13+
std::optional<const at::Tensor>& q_scale,
1314
std::optional<const at::Tensor>& k_scale,
1415
std::optional<const at::Tensor>& v_scale,
1516
double sm_scale,

csrc/xpu/attn/xe_2/chunk_prefill.hpp

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct chunk_prefill_args_t {
3636
int max_keys;
3737
int total_seqlen_q;
3838
int total_seqlen_k;
39+
void* q_scale;
3940
void* k_scale;
4041
void* v_scale;
4142
float sm_scale;
@@ -145,8 +146,9 @@ struct KernelLauncher {
145146
stride_V,
146147
reinterpret_cast<ElementO*>(args.out),
147148
stride_O,
148-
reinterpret_cast<ElementQ*>(args.sm_sink)},
149+
reinterpret_cast<ElementO*>(args.sm_sink)},
149150
{args.sm_scale,
151+
args.q_scale,
150152
args.k_scale,
151153
args.v_scale,
152154
static_cast<int*>(args.block_table),
@@ -232,9 +234,10 @@ template <
232234
struct FMHAConfig {
233235
static constexpr int SGTileQ =
234236
get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))();
237+
// Note that always use output dtype for MMAOperation
235238
using MMAOperation = cute::conditional_t<
236239
is_void_v<MMAOperation_>,
237-
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementQ>,
240+
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementO>,
238241
MMAOperation_>;
239242
using SubgroupLayoutPV = cute::conditional_t<
240243
is_void_v<SubgroupLayoutPV_>,
@@ -287,6 +290,7 @@ struct FMHAConfig {
287290
TensorQ,
288291
TensorK,
289292
TensorV,
293+
TensorO,
290294
GmemTiledCopyQ,
291295
GmemTiledCopyK,
292296
GmemTiledCopyV>;
@@ -320,11 +324,11 @@ struct FMHAConfig {
320324
template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
321325
void policy_dispatch_impl(
322326
sycl::queue& queue,
323-
CutlassQKType& cuQKType,
327+
CutlassQKOType& cuQKOType,
324328
const chunk_prefill_args_t& args) {
325329
const int PipelineStages = 2;
326-
if (cuQKType.q_type == CutlassDType::half) {
327-
if (cuQKType.k_type == CutlassDType::half) {
330+
if (cuQKOType.q_type == CutlassDType::half) {
331+
if (cuQKOType.k_type == CutlassDType::half) {
328332
return FMHAConfig<
329333
typename chunk_policy::ShapeQK,
330334
typename chunk_policy::ShapePV,
@@ -340,7 +344,7 @@ void policy_dispatch_impl(
340344
half_t,
341345
half_t,
342346
half_t>::kernel_dispatch(queue, args);
343-
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
347+
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
344348
return FMHAConfig<
345349
typename chunk_policy::ShapeQK,
346350
typename chunk_policy::ShapePV,
@@ -356,7 +360,7 @@ void policy_dispatch_impl(
356360
float_e4m3_t,
357361
float_e4m3_t,
358362
half_t>::kernel_dispatch(queue, args);
359-
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
363+
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
360364
return FMHAConfig<
361365
typename chunk_policy::ShapeQK,
362366
typename chunk_policy::ShapePV,
@@ -373,8 +377,76 @@ void policy_dispatch_impl(
373377
float_e5m2_t,
374378
half_t>::kernel_dispatch(queue, args);
375379
}
380+
} else if (cuQKOType.q_type == CutlassDType::float8_e4m3) {
381+
if (cuQKOType.o_type == CutlassDType::half) {
382+
return FMHAConfig<
383+
typename chunk_policy::ShapeQK,
384+
typename chunk_policy::ShapePV,
385+
typename chunk_policy::ShapeOut,
386+
typename chunk_policy::SubgroupLayoutQK,
387+
void,
388+
PipelineStages,
389+
Paged,
390+
Causal,
391+
Local,
392+
Sink,
393+
float_e4m3_t,
394+
float_e4m3_t,
395+
float_e4m3_t,
396+
half_t>::kernel_dispatch(queue, args);
397+
} else {
398+
return FMHAConfig<
399+
typename chunk_policy::ShapeQK,
400+
typename chunk_policy::ShapePV,
401+
typename chunk_policy::ShapeOut,
402+
typename chunk_policy::SubgroupLayoutQK,
403+
void,
404+
PipelineStages,
405+
Paged,
406+
Causal,
407+
Local,
408+
Sink,
409+
float_e4m3_t,
410+
float_e4m3_t,
411+
float_e4m3_t,
412+
bfloat16_t>::kernel_dispatch(queue, args);
413+
}
414+
} else if (cuQKOType.q_type == CutlassDType::float8_e5m2) {
415+
if (cuQKOType.o_type == CutlassDType::half) {
416+
return FMHAConfig<
417+
typename chunk_policy::ShapeQK,
418+
typename chunk_policy::ShapePV,
419+
typename chunk_policy::ShapeOut,
420+
typename chunk_policy::SubgroupLayoutQK,
421+
void,
422+
PipelineStages,
423+
Paged,
424+
Causal,
425+
Local,
426+
Sink,
427+
float_e5m2_t,
428+
float_e5m2_t,
429+
float_e5m2_t,
430+
half_t>::kernel_dispatch(queue, args);
431+
} else {
432+
return FMHAConfig<
433+
typename chunk_policy::ShapeQK,
434+
typename chunk_policy::ShapePV,
435+
typename chunk_policy::ShapeOut,
436+
typename chunk_policy::SubgroupLayoutQK,
437+
void,
438+
PipelineStages,
439+
Paged,
440+
Causal,
441+
Local,
442+
Sink,
443+
float_e5m2_t,
444+
float_e5m2_t,
445+
float_e5m2_t,
446+
bfloat16_t>::kernel_dispatch(queue, args);
447+
}
376448
} else {
377-
if (cuQKType.k_type == CutlassDType::bfloat16) {
449+
if (cuQKOType.k_type == CutlassDType::bfloat16) {
378450
return FMHAConfig<
379451
typename chunk_policy::ShapeQK,
380452
typename chunk_policy::ShapePV,
@@ -390,7 +462,7 @@ void policy_dispatch_impl(
390462
bfloat16_t,
391463
bfloat16_t,
392464
bfloat16_t>::kernel_dispatch(queue, args);
393-
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
465+
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
394466
return FMHAConfig<
395467
typename chunk_policy::ShapeQK,
396468
typename chunk_policy::ShapePV,
@@ -406,7 +478,7 @@ void policy_dispatch_impl(
406478
float_e4m3_t,
407479
float_e4m3_t,
408480
bfloat16_t>::kernel_dispatch(queue, args);
409-
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
481+
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
410482
return FMHAConfig<
411483
typename chunk_policy::ShapeQK,
412484
typename chunk_policy::ShapePV,

csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
extern template void \
2727
policy_dispatch_impl<POLICY, PAGED, CAUSAL, LOCAL, SINK>( \
2828
sycl::queue & queue, \
29-
CutlassQKType & cuQKType, \
29+
CutlassQKOType & cuQKOType, \
3030
const chunk_prefill_args_t& args);
3131

3232
// Generate all 16 bool combinations for a given policy using nested macros

csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using namespace cute;
2121
static_cast<bool>(IMPL_KISLOCAL), \
2222
static_cast<bool>(IMPL_KISSINK)>( \
2323
sycl::queue & queue, \
24-
CutlassQKType& cuQKType, \
24+
CutlassQKOType& cuQKOType, \
2525
const chunk_prefill_args_t& args);
2626

2727
INSTANTIATE_KERNEL()

csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@ using namespace cute;
55
template <typename chunk_policy, bool... Bs>
66
void policy_dispatch_func(
77
sycl::queue& queue,
8-
CutlassQKType& cuQKType,
8+
CutlassQKOType& cuQKOType,
99
const chunk_prefill_args_t& args) {
10-
policy_dispatch_impl<chunk_policy, Bs...>(queue, cuQKType, args);
10+
policy_dispatch_impl<chunk_policy, Bs...>(queue, cuQKOType, args);
1111
}
1212

1313
template <typename chunk_policy, bool... Bs, typename... Ts>
1414
void policy_dispatch_func(
1515
sycl::queue& queue,
16-
CutlassQKType& cuQKType,
16+
CutlassQKOType& cuQKOType,
1717
const chunk_prefill_args_t& args,
1818
bool b,
1919
Ts... ts) {
2020
if (b) {
2121
policy_dispatch_func<chunk_policy, Bs..., true>(
22-
queue, cuQKType, args, ts...);
22+
queue, cuQKOType, args, ts...);
2323
} else {
2424
policy_dispatch_func<chunk_policy, Bs..., false>(
25-
queue, cuQKType, args, ts...);
25+
queue, cuQKOType, args, ts...);
2626
}
2727
}
2828

@@ -37,6 +37,7 @@ void cutlass_chunk_prefill_impl(
3737
const at::Tensor& cu_seqlens_k,
3838
int max_seqlen_q,
3939
int max_seqlen_k,
40+
std::optional<const at::Tensor>& q_scale,
4041
std::optional<const at::Tensor>& k_scale,
4142
std::optional<const at::Tensor>& v_scale,
4243
double sm_scale,

csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class FMHAFwdEpilogue {
7777

7878
// softmax sink, same dtype
7979
static constexpr bool Sink = Sink_;
80-
using ElementSink = typename CollectiveMainloop::TensorQ::element_type;
80+
using ElementSink = typename CollectiveMainloop::TensorO::element_type;
8181

8282
// Split k-reduced tiles between participating subgroups.
8383
// Assumption: the A tile is contiguous.

0 commit comments

Comments (0)