@@ -36,6 +36,7 @@ struct chunk_prefill_args_t {
3636 int max_keys;
3737 int total_seqlen_q;
3838 int total_seqlen_k;
39+ void * q_scale;
3940 void * k_scale;
4041 void * v_scale;
4142 float sm_scale;
@@ -145,8 +146,9 @@ struct KernelLauncher {
145146 stride_V,
146147 reinterpret_cast <ElementO*>(args.out ),
147148 stride_O,
148- reinterpret_cast <ElementQ *>(args.sm_sink )},
149+ reinterpret_cast <ElementO *>(args.sm_sink )},
149150 {args.sm_scale ,
151+ args.q_scale ,
150152 args.k_scale ,
151153 args.v_scale ,
152154 static_cast <int *>(args.block_table ),
@@ -232,9 +234,10 @@ template <
232234struct FMHAConfig {
233235 static constexpr int SGTileQ =
234236 get<0 >(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))();
237+ // Note that always use output dtype for MMAOperation
235238 using MMAOperation = cute::conditional_t <
236239 is_void_v<MMAOperation_>,
237- XE_DPAS_TT<cute::gcd(SGTileQ, 8 ), float , ElementQ >,
240+ XE_DPAS_TT<cute::gcd(SGTileQ, 8 ), float , ElementO >,
238241 MMAOperation_>;
239242 using SubgroupLayoutPV = cute::conditional_t <
240243 is_void_v<SubgroupLayoutPV_>,
@@ -287,6 +290,7 @@ struct FMHAConfig {
287290 TensorQ,
288291 TensorK,
289292 TensorV,
293+ TensorO,
290294 GmemTiledCopyQ,
291295 GmemTiledCopyK,
292296 GmemTiledCopyV>;
@@ -320,11 +324,11 @@ struct FMHAConfig {
320324template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
321325void policy_dispatch_impl (
322326 sycl::queue& queue,
323- CutlassQKType& cuQKType ,
327+ CutlassQKOType& cuQKOType ,
324328 const chunk_prefill_args_t & args) {
325329 const int PipelineStages = 2 ;
326- if (cuQKType .q_type == CutlassDType::half) {
327- if (cuQKType .k_type == CutlassDType::half) {
330+ if (cuQKOType .q_type == CutlassDType::half) {
331+ if (cuQKOType .k_type == CutlassDType::half) {
328332 return FMHAConfig<
329333 typename chunk_policy::ShapeQK,
330334 typename chunk_policy::ShapePV,
@@ -340,7 +344,7 @@ void policy_dispatch_impl(
340344 half_t ,
341345 half_t ,
342346 half_t >::kernel_dispatch (queue, args);
343- } else if (cuQKType .k_type == CutlassDType::float8_e4m3) {
347+ } else if (cuQKOType .k_type == CutlassDType::float8_e4m3) {
344348 return FMHAConfig<
345349 typename chunk_policy::ShapeQK,
346350 typename chunk_policy::ShapePV,
@@ -356,7 +360,7 @@ void policy_dispatch_impl(
356360 float_e4m3_t ,
357361 float_e4m3_t ,
358362 half_t >::kernel_dispatch (queue, args);
359- } else if (cuQKType .k_type == CutlassDType::float8_e5m2) {
363+ } else if (cuQKOType .k_type == CutlassDType::float8_e5m2) {
360364 return FMHAConfig<
361365 typename chunk_policy::ShapeQK,
362366 typename chunk_policy::ShapePV,
@@ -373,8 +377,76 @@ void policy_dispatch_impl(
373377 float_e5m2_t ,
374378 half_t >::kernel_dispatch (queue, args);
375379 }
380+ } else if (cuQKOType.q_type == CutlassDType::float8_e4m3) {
381+ if (cuQKOType.o_type == CutlassDType::half) {
382+ return FMHAConfig<
383+ typename chunk_policy::ShapeQK,
384+ typename chunk_policy::ShapePV,
385+ typename chunk_policy::ShapeOut,
386+ typename chunk_policy::SubgroupLayoutQK,
387+ void ,
388+ PipelineStages,
389+ Paged,
390+ Causal,
391+ Local,
392+ Sink,
393+ float_e4m3_t ,
394+ float_e4m3_t ,
395+ float_e4m3_t ,
396+ half_t >::kernel_dispatch (queue, args);
397+ } else {
398+ return FMHAConfig<
399+ typename chunk_policy::ShapeQK,
400+ typename chunk_policy::ShapePV,
401+ typename chunk_policy::ShapeOut,
402+ typename chunk_policy::SubgroupLayoutQK,
403+ void ,
404+ PipelineStages,
405+ Paged,
406+ Causal,
407+ Local,
408+ Sink,
409+ float_e4m3_t ,
410+ float_e4m3_t ,
411+ float_e4m3_t ,
412+ bfloat16_t >::kernel_dispatch (queue, args);
413+ }
414+ } else if (cuQKOType.q_type == CutlassDType::float8_e5m2) {
415+ if (cuQKOType.o_type == CutlassDType::half) {
416+ return FMHAConfig<
417+ typename chunk_policy::ShapeQK,
418+ typename chunk_policy::ShapePV,
419+ typename chunk_policy::ShapeOut,
420+ typename chunk_policy::SubgroupLayoutQK,
421+ void ,
422+ PipelineStages,
423+ Paged,
424+ Causal,
425+ Local,
426+ Sink,
427+ float_e5m2_t ,
428+ float_e5m2_t ,
429+ float_e5m2_t ,
430+ half_t >::kernel_dispatch (queue, args);
431+ } else {
432+ return FMHAConfig<
433+ typename chunk_policy::ShapeQK,
434+ typename chunk_policy::ShapePV,
435+ typename chunk_policy::ShapeOut,
436+ typename chunk_policy::SubgroupLayoutQK,
437+ void ,
438+ PipelineStages,
439+ Paged,
440+ Causal,
441+ Local,
442+ Sink,
443+ float_e5m2_t ,
444+ float_e5m2_t ,
445+ float_e5m2_t ,
446+ bfloat16_t >::kernel_dispatch (queue, args);
447+ }
376448 } else {
377- if (cuQKType .k_type == CutlassDType::bfloat16) {
449+ if (cuQKOType .k_type == CutlassDType::bfloat16) {
378450 return FMHAConfig<
379451 typename chunk_policy::ShapeQK,
380452 typename chunk_policy::ShapePV,
@@ -390,7 +462,7 @@ void policy_dispatch_impl(
390462 bfloat16_t ,
391463 bfloat16_t ,
392464 bfloat16_t >::kernel_dispatch (queue, args);
393- } else if (cuQKType .k_type == CutlassDType::float8_e4m3) {
465+ } else if (cuQKOType .k_type == CutlassDType::float8_e4m3) {
394466 return FMHAConfig<
395467 typename chunk_policy::ShapeQK,
396468 typename chunk_policy::ShapePV,
@@ -406,7 +478,7 @@ void policy_dispatch_impl(
406478 float_e4m3_t ,
407479 float_e4m3_t ,
408480 bfloat16_t >::kernel_dispatch (queue, args);
409- } else if (cuQKType .k_type == CutlassDType::float8_e5m2) {
481+ } else if (cuQKOType .k_type == CutlassDType::float8_e5m2) {
410482 return FMHAConfig<
411483 typename chunk_policy::ShapeQK,
412484 typename chunk_policy::ShapePV,
0 commit comments