Skip to content

Commit c8bf0cf

Browse files
committed
explicitly instantiation with dtype combinations
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent c3df9d2 commit c8bf0cf

File tree

6 files changed

+373
-223
lines changed

6 files changed

+373
-223
lines changed

csrc/xpu/attn/xe_2/chunk_prefill.hpp

Lines changed: 25 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -321,179 +321,31 @@ struct FMHAConfig {
321321
}
322322
};
323323

324-
template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
324+
template <
325+
typename chunk_policy,
326+
typename ElementQ,
327+
typename ElementKV,
328+
typename ElementO,
329+
bool Paged,
330+
bool Causal,
331+
bool Local,
332+
bool Sink>
325333
void policy_dispatch_impl(
326-
sycl::queue& queue,
327-
CutlassQKOType& cuQKOType,
328-
const chunk_prefill_args_t& args) {
334+
sycl::queue& queue, const chunk_prefill_args_t& args) {
329335
const int PipelineStages = 2;
330-
if (cuQKOType.q_type == CutlassDType::half) {
331-
if (cuQKOType.k_type == CutlassDType::half) {
332-
return FMHAConfig<
333-
typename chunk_policy::ShapeQK,
334-
typename chunk_policy::ShapePV,
335-
typename chunk_policy::ShapeOut,
336-
typename chunk_policy::SubgroupLayoutQK,
337-
void,
338-
PipelineStages,
339-
Paged,
340-
Causal,
341-
Local,
342-
Sink,
343-
half_t,
344-
half_t,
345-
half_t,
346-
half_t>::kernel_dispatch(queue, args);
347-
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
348-
return FMHAConfig<
349-
typename chunk_policy::ShapeQK,
350-
typename chunk_policy::ShapePV,
351-
typename chunk_policy::ShapeOut,
352-
typename chunk_policy::SubgroupLayoutQK,
353-
void,
354-
PipelineStages,
355-
Paged,
356-
Causal,
357-
Local,
358-
Sink,
359-
half_t,
360-
float_e4m3_t,
361-
float_e4m3_t,
362-
half_t>::kernel_dispatch(queue, args);
363-
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
364-
return FMHAConfig<
365-
typename chunk_policy::ShapeQK,
366-
typename chunk_policy::ShapePV,
367-
typename chunk_policy::ShapeOut,
368-
typename chunk_policy::SubgroupLayoutQK,
369-
void,
370-
PipelineStages,
371-
Paged,
372-
Causal,
373-
Local,
374-
Sink,
375-
half_t,
376-
float_e5m2_t,
377-
float_e5m2_t,
378-
half_t>::kernel_dispatch(queue, args);
379-
}
380-
} else if (cuQKOType.q_type == CutlassDType::float8_e4m3) {
381-
if (cuQKOType.o_type == CutlassDType::half) {
382-
return FMHAConfig<
383-
typename chunk_policy::ShapeQK,
384-
typename chunk_policy::ShapePV,
385-
typename chunk_policy::ShapeOut,
386-
typename chunk_policy::SubgroupLayoutQK,
387-
void,
388-
PipelineStages,
389-
Paged,
390-
Causal,
391-
Local,
392-
Sink,
393-
float_e4m3_t,
394-
float_e4m3_t,
395-
float_e4m3_t,
396-
half_t>::kernel_dispatch(queue, args);
397-
} else {
398-
return FMHAConfig<
399-
typename chunk_policy::ShapeQK,
400-
typename chunk_policy::ShapePV,
401-
typename chunk_policy::ShapeOut,
402-
typename chunk_policy::SubgroupLayoutQK,
403-
void,
404-
PipelineStages,
405-
Paged,
406-
Causal,
407-
Local,
408-
Sink,
409-
float_e4m3_t,
410-
float_e4m3_t,
411-
float_e4m3_t,
412-
bfloat16_t>::kernel_dispatch(queue, args);
413-
}
414-
} else if (cuQKOType.q_type == CutlassDType::float8_e5m2) {
415-
if (cuQKOType.o_type == CutlassDType::half) {
416-
return FMHAConfig<
417-
typename chunk_policy::ShapeQK,
418-
typename chunk_policy::ShapePV,
419-
typename chunk_policy::ShapeOut,
420-
typename chunk_policy::SubgroupLayoutQK,
421-
void,
422-
PipelineStages,
423-
Paged,
424-
Causal,
425-
Local,
426-
Sink,
427-
float_e5m2_t,
428-
float_e5m2_t,
429-
float_e5m2_t,
430-
half_t>::kernel_dispatch(queue, args);
431-
} else {
432-
return FMHAConfig<
433-
typename chunk_policy::ShapeQK,
434-
typename chunk_policy::ShapePV,
435-
typename chunk_policy::ShapeOut,
436-
typename chunk_policy::SubgroupLayoutQK,
437-
void,
438-
PipelineStages,
439-
Paged,
440-
Causal,
441-
Local,
442-
Sink,
443-
float_e5m2_t,
444-
float_e5m2_t,
445-
float_e5m2_t,
446-
bfloat16_t>::kernel_dispatch(queue, args);
447-
}
448-
} else {
449-
if (cuQKOType.k_type == CutlassDType::bfloat16) {
450-
return FMHAConfig<
451-
typename chunk_policy::ShapeQK,
452-
typename chunk_policy::ShapePV,
453-
typename chunk_policy::ShapeOut,
454-
typename chunk_policy::SubgroupLayoutQK,
455-
void,
456-
PipelineStages,
457-
Paged,
458-
Causal,
459-
Local,
460-
Sink,
461-
bfloat16_t,
462-
bfloat16_t,
463-
bfloat16_t,
464-
bfloat16_t>::kernel_dispatch(queue, args);
465-
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
466-
return FMHAConfig<
467-
typename chunk_policy::ShapeQK,
468-
typename chunk_policy::ShapePV,
469-
typename chunk_policy::ShapeOut,
470-
typename chunk_policy::SubgroupLayoutQK,
471-
void,
472-
PipelineStages,
473-
Paged,
474-
Causal,
475-
Local,
476-
Sink,
477-
bfloat16_t,
478-
float_e4m3_t,
479-
float_e4m3_t,
480-
bfloat16_t>::kernel_dispatch(queue, args);
481-
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
482-
return FMHAConfig<
483-
typename chunk_policy::ShapeQK,
484-
typename chunk_policy::ShapePV,
485-
typename chunk_policy::ShapeOut,
486-
typename chunk_policy::SubgroupLayoutQK,
487-
void,
488-
PipelineStages,
489-
Paged,
490-
Causal,
491-
Local,
492-
Sink,
493-
bfloat16_t,
494-
float_e5m2_t,
495-
float_e5m2_t,
496-
bfloat16_t>::kernel_dispatch(queue, args);
497-
}
498-
}
336+
return FMHAConfig<
337+
typename chunk_policy::ShapeQK,
338+
typename chunk_policy::ShapePV,
339+
typename chunk_policy::ShapeOut,
340+
typename chunk_policy::SubgroupLayoutQK,
341+
void,
342+
PipelineStages,
343+
Paged,
344+
Causal,
345+
Local,
346+
Sink,
347+
ElementQ,
348+
ElementKV,
349+
ElementKV,
350+
ElementO>::kernel_dispatch(queue, args);
499351
}

csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,51 @@
11
function(fmha_forward_configure FILENAME_SUFFIX)
22
set(GEN_KERNEL_SRCS) # output
3-
set(L_TYPES "fp16" "bf16")
43
set(L_BOOLS "false" "true")
54
set(BOOL_FLAG_false "f")
65
set(BOOL_FLAG_true "t")
76
set(policy_list
87
"chunk_policy_head64" "chunk_policy_head96" "chunk_policy_head128"
98
"chunk_policy_head192" "chunk_policy_head256")
109

11-
set(IMPL_KV_T "fp16")
10+
# Allowed dtype combinations must match runtime dispatch constraints. Format:
11+
# Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG
12+
set(dtype_combo_list
13+
"half_t|half_t|half_t|h_h_h"
14+
"half_t|float_e4m3_t|half_t|h_e4_h"
15+
"half_t|float_e5m2_t|half_t|h_e5_h"
16+
"bfloat16_t|bfloat16_t|bfloat16_t|b_b_b"
17+
"bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b"
18+
"bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b"
19+
"float_e4m3_t|float_e4m3_t|half_t|e4_e4_h"
20+
"float_e4m3_t|float_e4m3_t|bfloat16_t|e4_e4_b"
21+
"float_e5m2_t|float_e5m2_t|half_t|e5_e5_h"
22+
"float_e5m2_t|float_e5m2_t|bfloat16_t|e5_e5_b")
1223

1324
foreach(IMPL_POLICY ${policy_list})
14-
# foreach(IMPL_T ${L_TYPES})
15-
foreach(IMPL_KISPAGED ${L_BOOLS})
16-
foreach(IMPL_KISCAUSAL ${L_BOOLS})
17-
foreach(IMPL_KISLOCAL ${L_BOOLS})
18-
foreach(IMPL_KISSINK ${L_BOOLS})
19-
set(FILE_SUFFIX "${IMPL_POLICY}_")
20-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
21-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
22-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
23-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
24-
configure_file(${FILENAME_SUFFIX}.cpp.in
25-
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
26-
list(
27-
APPEND
28-
GEN_KERNEL_SRCS
29-
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
30-
)
25+
foreach(dtype_combo ${dtype_combo_list})
26+
string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
27+
list(GET dtype_parts 0 IMPL_Q_T)
28+
list(GET dtype_parts 1 IMPL_KV_T)
29+
list(GET dtype_parts 2 IMPL_O_T)
30+
list(GET dtype_parts 3 DTYPE_TAG)
31+
32+
foreach(IMPL_KISPAGED ${L_BOOLS})
33+
foreach(IMPL_KISCAUSAL ${L_BOOLS})
34+
foreach(IMPL_KISLOCAL ${L_BOOLS})
35+
foreach(IMPL_KISSINK ${L_BOOLS})
36+
set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_")
37+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
38+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
39+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
40+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
41+
configure_file(${FILENAME_SUFFIX}.cpp.in
42+
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
43+
list(
44+
APPEND
45+
GEN_KERNEL_SRCS
46+
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
47+
)
48+
endforeach()
3149
endforeach()
3250
endforeach()
3351
endforeach()

csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,77 @@
2121
// that include chunk_prefill.hpp. Each specialization is explicitly
2222
// instantiated in its own .cpp file generated by CMake.
2323

24-
// Helper macro to declare a single extern template instantiation
25-
#define DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
26-
extern template void \
27-
policy_dispatch_impl<POLICY, PAGED, CAUSAL, LOCAL, SINK>( \
28-
sycl::queue & queue, \
29-
CutlassQKOType & cuQKOType, \
30-
const chunk_prefill_args_t& args);
24+
// Helper macro to declare a single extern template instantiation.
25+
// Template order is: policy, ElementQ, ElementKV, ElementO, bools...
26+
#define DECLARE_POLICY_DISPATCH_EXTERN( \
27+
POLICY, Q_T, KV_T, O_T, PAGED, CAUSAL, LOCAL, SINK) \
28+
extern template void \
29+
policy_dispatch_impl<POLICY, Q_T, KV_T, O_T, PAGED, CAUSAL, LOCAL, SINK>( \
30+
sycl::queue & queue, const chunk_prefill_args_t& args);
31+
32+
// Allowed dtype combinations (must match runtime dispatch constraints):
33+
// 1) Q=half -> O=half, KV in {half, fp8_e4m3, fp8_e5m2}
34+
// 2) Q=bf16 -> O=bf16, KV in {bf16, fp8_e4m3, fp8_e5m2}
35+
// 3) Q=fp8_e4m3-> O in {half,bf16}, KV=fp8_e4m3
36+
// 4) Q=fp8_e5m2-> O in {half,bf16}, KV=fp8_e5m2
37+
#define DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
38+
DECLARE_POLICY_DISPATCH_EXTERN( \
39+
POLICY, half_t, half_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
40+
DECLARE_POLICY_DISPATCH_EXTERN( \
41+
POLICY, half_t, float_e4m3_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
42+
DECLARE_POLICY_DISPATCH_EXTERN( \
43+
POLICY, half_t, float_e5m2_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
44+
DECLARE_POLICY_DISPATCH_EXTERN( \
45+
POLICY, bfloat16_t, bfloat16_t, bfloat16_t, PAGED, CAUSAL, LOCAL, SINK) \
46+
DECLARE_POLICY_DISPATCH_EXTERN( \
47+
POLICY, \
48+
bfloat16_t, \
49+
float_e4m3_t, \
50+
bfloat16_t, \
51+
PAGED, \
52+
CAUSAL, \
53+
LOCAL, \
54+
SINK) \
55+
DECLARE_POLICY_DISPATCH_EXTERN( \
56+
POLICY, \
57+
bfloat16_t, \
58+
float_e5m2_t, \
59+
bfloat16_t, \
60+
PAGED, \
61+
CAUSAL, \
62+
LOCAL, \
63+
SINK) \
64+
DECLARE_POLICY_DISPATCH_EXTERN( \
65+
POLICY, float_e4m3_t, float_e4m3_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
66+
DECLARE_POLICY_DISPATCH_EXTERN( \
67+
POLICY, \
68+
float_e4m3_t, \
69+
float_e4m3_t, \
70+
bfloat16_t, \
71+
PAGED, \
72+
CAUSAL, \
73+
LOCAL, \
74+
SINK) \
75+
DECLARE_POLICY_DISPATCH_EXTERN( \
76+
POLICY, float_e5m2_t, float_e5m2_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
77+
DECLARE_POLICY_DISPATCH_EXTERN( \
78+
POLICY, \
79+
float_e5m2_t, \
80+
float_e5m2_t, \
81+
bfloat16_t, \
82+
PAGED, \
83+
CAUSAL, \
84+
LOCAL, \
85+
SINK)
3186

3287
// Generate all 16 bool combinations for a given policy using nested macros
3388
// Pattern: Paged, Causal, Local, Sink (all permutations of 4 bools = 2^4 = 16)
3489
// This hierarchical approach makes it easy to extend to more bool parameters
3590

3691
// Level 4: Iterate over Sink values (innermost)
37-
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
38-
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, false) \
39-
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, true)
92+
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
93+
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, false) \
94+
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, true)
4095

4196
// Level 3: Iterate over Local values
4297
#define DECLARE_FOR_LOCAL(POLICY, PAGED, CAUSAL) \
@@ -61,4 +116,5 @@ CHUNK_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS)
61116
#undef DECLARE_FOR_CAUSAL
62117
#undef DECLARE_FOR_LOCAL
63118
#undef DECLARE_FOR_SINK
119+
#undef DECLARE_ALLOWED_DTYPES
64120
#undef DECLARE_POLICY_DISPATCH_EXTERN

csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ using namespace cute;
44

55
// clang-format off
66
// macros to be filled in CMake
7-
#define IMPL_T ${IMPL_T}
7+
#define IMPL_Q_T ${IMPL_Q_T}
88
#define IMPL_KV_T ${IMPL_KV_T}
9+
#define IMPL_O_T ${IMPL_O_T}
910
#define IMPL_POLICY ${IMPL_POLICY}
1011
#cmakedefine01 IMPL_KISPAGED
1112
#cmakedefine01 IMPL_KISCAUSAL
@@ -16,12 +17,14 @@ using namespace cute;
1617
#define INSTANTIATE_KERNEL() \
1718
template void policy_dispatch_impl< \
1819
IMPL_POLICY, \
20+
IMPL_Q_T, \
21+
IMPL_KV_T, \
22+
IMPL_O_T, \
1923
static_cast<bool>(IMPL_KISPAGED), \
2024
static_cast<bool>(IMPL_KISCAUSAL), \
2125
static_cast<bool>(IMPL_KISLOCAL), \
2226
static_cast<bool>(IMPL_KISSINK)>( \
2327
sycl::queue & queue, \
24-
CutlassQKOType& cuQKOType, \
2528
const chunk_prefill_args_t& args);
2629

2730
INSTANTIATE_KERNEL()

0 commit comments

Comments
 (0)