Skip to content

Commit dcef1c4

Browse files
committed
[Test] Enable max_variable_dim parameter for grouped_gemm
Requires oneDNN max_variable_dim support: https://github.com/uxlfoundation/oneDNN/commits/mzhukova/main/max_variable_dim/
1 parent 9bd5207 commit dcef1c4

File tree

1 file changed

+61
-12
lines changed

1 file changed

+61
-12
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define DEBUG_MOE_LOG 0
1111

1212
#ifdef ENABLE_ONEDNN_FOR_GPU
13+
# include <algorithm>
1314
# include <initializer_list>
1415
# include <oneapi/dnnl/dnnl.hpp>
1516
# include <oneapi/dnnl/dnnl_ocl.hpp>
@@ -1902,7 +1903,7 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
19021903
dnnl::memory::desc down_zp_md;
19031904
bool has_zp = false;
19041905
};
1905-
using grouped_kernel_lru = LruCache<int, std::shared_ptr<grouped_onednn_kernel>, std::hash<int>>;
1906+
using grouped_kernel_lru = LruCache<std::pair<int, int>, std::shared_ptr<grouped_onednn_kernel>, PairHash>;
19061907
grouped_kernel_lru _grouped_kernels{128};
19071908
onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) {
19081909
auto key = std::make_pair(n_token, expert_no);
@@ -1964,11 +1965,49 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
19641965
return *_kernels.get(key);
19651966
}
19661967

1967-
// Build (and cache) three grouped dnnl::matmul primitives for gate/up/down,
1968-
// keyed by total_gathered_tokens to handle variable-length prefill batches.
1969-
grouped_onednn_kernel& get_grouped_kernel(int total_tokens, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) {
1970-
if (_grouped_kernels.has(total_tokens)) {
1971-
return *_grouped_kernels.get(total_tokens);
1968+
// Quantize max_tokens_per_expert into a bucketed upper bound to limit the
1969+
// number of distinct cached primitives while keeping dispatch-safe values.
1970+
//
1971+
// Strategy adapts bucket granularity to total_tokens:
1972+
// total_tokens <= 128 : no bucketing (host overhead dominates)
1973+
// 128 < total <= 1024 : 4 buckets, bucket_size aligned to 32
1974+
// 1024 < total <= 8192 : 8 buckets, bucket_size aligned to 32
1975+
// total > 8192 : fixed bucket_size = 1024
1976+
// The last bucket always caps at total_tokens to guarantee safety.
1977+
// Quantize max_tokens_per_expert into a bucketed upper bound to limit the
// number of distinct cached primitives while keeping dispatch-safe values.
//
// Strategy adapts bucket granularity to total_tokens:
//   total_tokens <= 128  : no bucketing (host overhead dominates)
//   128 < total <= 1024  : 4 buckets, bucket_size aligned to 32
//   1024 < total <= 8192 : 8 buckets, bucket_size aligned to 32
//   total > 8192         : fixed bucket_size = 1024
// The last bucket always caps at total_tokens to guarantee safety.
//
// @param max_tokens_per_expert actual max tokens routed to any single expert
// @param total_tokens          total gathered tokens across all experts
// @return a value >= max_tokens_per_expert (degenerate inputs aside),
//         clamped to total_tokens, suitable as oneDNN max_variable_dim
static int bucket_max_variable_dim(int max_tokens_per_expert, int total_tokens) {
    // Degenerate inputs: fall back to the always-safe bound.
    if (max_tokens_per_expert <= 0 || total_tokens <= 0)
        return total_tokens;

    // Short sequences: host-side overhead dominates, skip bucketing
    if (total_tokens <= 128)
        return total_tokens;

    int bucket_size;
    if (total_tokens <= 1024) {
        // 4 buckets, first 3 have 32-aligned width, last = total_tokens
        bucket_size = (((total_tokens + 3) / 4) + 31) / 32 * 32;
    } else if (total_tokens <= 8192) {
        // 8 buckets, first 7 have 32-aligned width, last = total_tokens
        bucket_size = (((total_tokens + 7) / 8) + 31) / 32 * 32;
    } else {
        // Fixed 1024-wide buckets for very long sequences
        bucket_size = 1024;
    }

    // Snap max_tokens_per_expert up to the next bucket ceiling so the result
    // is a true per-group upper bound, then clamp to total_tokens for the
    // last bucket. NOTE(review): the previous shortcut
    //   return std::min(bucket_size, total_tokens);
    // ignored max_tokens_per_expert and could return a value BELOW the actual
    // per-expert maximum (whenever max_tokens_per_expert > bucket_size),
    // violating the dispatch-safety guarantee documented above and collapsing
    // the intended 4-10 buckets into a single cached primitive per
    // total_tokens value.
    int bucketed = ((max_tokens_per_expert + bucket_size - 1) / bucket_size) * bucket_size;
    return std::min(bucketed, total_tokens);
}
2002+
2003+
// Build (and cache) three grouped dnnl::matmul primitives for gate/up/down.
2004+
// Cache key is (total_tokens, bucketed_max_variable_dim). The adaptive
2005+
// bucketing limits distinct primitives to 4-10 per total_tokens value.
2006+
grouped_onednn_kernel& get_grouped_kernel(int total_tokens, int max_tokens_per_expert, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) {
2007+
int max_variable_dim = bucket_max_variable_dim(max_tokens_per_expert, total_tokens);
2008+
auto key = std::make_pair(total_tokens, max_variable_dim);
2009+
if (_grouped_kernels.has(key)) {
2010+
return *_grouped_kernels.get(key);
19722011
}
19732012

19742013
auto cur_moe = instance.get_typed_desc<moe_3gemm_fused_compressed>();
@@ -2012,8 +2051,11 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
20122051
}
20132052

20142053
// Grouped src/dst: tokens are grouped by expert along axis-0
2015-
auto src_md = dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, K}, a_dt, 0, num_experts);
2016-
auto dst_md = dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, N}, a_dt, 0, num_experts);
2054+
// max_variable_dim provides a static per-group upper bound for dispatch optimization
2055+
auto src_md =
2056+
dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, K}, a_dt, 0, num_experts, dnnl::memory::data_type::s32, max_variable_dim);
2057+
auto dst_md =
2058+
dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, N}, a_dt, 0, num_experts, dnnl::memory::data_type::s32, max_variable_dim);
20172059
// Weight: logical [E, K, N], physical layout acb -> stored as [E, N, K]
20182060
auto w_md = dnnl::memory::desc(dnnl::memory::dims{num_experts, K, N}, w_dt, dnnl::memory::format_tag::acb);
20192061

@@ -2051,8 +2093,8 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
20512093
if (has_zp)
20522094
gk->down_zp_md = make_quant_md(num_experts, K_d, _down_group_size, N_d, dw_dt);
20532095

2054-
_grouped_kernels.add(total_tokens, gk);
2055-
return *_grouped_kernels.get(total_tokens);
2096+
_grouped_kernels.add(key, gk);
2097+
return *_grouped_kernels.get(key);
20562098
}
20572099

20582100
// inputs 0 is hidden_states, inputs 1 is router_logits[num_tokens, NUM_EXPERTS=128]
@@ -2242,8 +2284,15 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
22422284
}
22432285
int total_gathered_tokens = static_cast<int>(token_num) * max_topk;
22442286

2287+
// Compute actual max tokens assigned to any single expert.
2288+
int max_tokens_per_expert = 0;
2289+
if (num_actually_used_experts > 0) {
2290+
max_tokens_per_expert = *std::max_element(tokens_lens_per_expert_cpu.begin(), tokens_lens_per_expert_cpu.begin() + num_actually_used_experts);
2291+
}
2292+
22452293
GPU_DEBUG_TRACE_DETAIL << "\nexec_prefill_grouped_gemm: token_num=" << token_num << ", total_gathered_tokens=" << total_gathered_tokens
2246-
<< ", num_actually_used_experts=" << num_actually_used_experts << std::endl;
2294+
<< ", max_tokens_per_expert=" << max_tokens_per_expert << ", num_actually_used_experts=" << num_actually_used_experts
2295+
<< std::endl;
22472296

22482297
// Upload scratch metadata for the scatter_reduce and gather kernels
22492298
intermediates_memories[MOE_INTERNAL_BUFFER_TOKEN_IDX_PER_EXPERT]
@@ -2290,7 +2339,7 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL {
22902339
// ----------------------------------------------------------------
22912340
// Steps 3-5: OneDNN grouped GEMM – gate, up, SiLU, down
22922341
// ----------------------------------------------------------------
2293-
auto& gk = get_grouped_kernel(total_gathered_tokens, instance);
2342+
auto& gk = get_grouped_kernel(static_cast<int>(token_num), max_tokens_per_expert, instance);
22942343
auto* offsets_ptr = intermediates_memories[MOE_INTERNAL_BUFFER_GROUPED_OFFSETS]->buffer_ptr();
22952344

22962345
// Helper: wrap a flat USM buffer as an OneDNN grouped memory (data + expert row-offsets)

0 commit comments

Comments
 (0)