|
10 | 10 | #define DEBUG_MOE_LOG 0 |
11 | 11 |
|
12 | 12 | #ifdef ENABLE_ONEDNN_FOR_GPU |
| 13 | +# include <algorithm> |
13 | 14 | # include <initializer_list> |
14 | 15 | # include <oneapi/dnnl/dnnl.hpp> |
15 | 16 | # include <oneapi/dnnl/dnnl_ocl.hpp> |
@@ -1902,7 +1903,7 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
1902 | 1903 | dnnl::memory::desc down_zp_md; |
1903 | 1904 | bool has_zp = false; |
1904 | 1905 | }; |
1905 | | - using grouped_kernel_lru = LruCache<int, std::shared_ptr<grouped_onednn_kernel>, std::hash<int>>; |
| 1906 | + using grouped_kernel_lru = LruCache<std::pair<int, int>, std::shared_ptr<grouped_onednn_kernel>, PairHash>; |
1906 | 1907 | grouped_kernel_lru _grouped_kernels{128}; |
1907 | 1908 | onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) { |
1908 | 1909 | auto key = std::make_pair(n_token, expert_no); |
@@ -1964,11 +1965,49 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
1964 | 1965 | return *_kernels.get(key); |
1965 | 1966 | } |
1966 | 1967 |
|
1967 | | - // Build (and cache) three grouped dnnl::matmul primitives for gate/up/down, |
1968 | | - // keyed by total_gathered_tokens to handle variable-length prefill batches. |
1969 | | - grouped_onednn_kernel& get_grouped_kernel(int total_tokens, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) { |
1970 | | - if (_grouped_kernels.has(total_tokens)) { |
1971 | | - return *_grouped_kernels.get(total_tokens); |
// Quantize max_tokens_per_expert into a bucketed upper bound to limit the
// number of distinct cached primitives while keeping dispatch-safe values.
//
// Strategy adapts bucket granularity to total_tokens:
//   total_tokens <= 128  : no bucketing (host overhead dominates)
//   128 < total <= 1024  : 4 buckets, bucket_size aligned to 32
//   1024 < total <= 8192 : 8 buckets, bucket_size aligned to 32
//   total > 8192         : fixed bucket_size = 1024
// The returned value is always >= max_tokens_per_expert (dispatch safety)
// and always <= total_tokens (the last bucket caps at total_tokens).
//
// @param max_tokens_per_expert  largest token count assigned to any single expert
// @param total_tokens           total gathered tokens across all experts
// @return bucketed upper bound on the per-expert variable dimension
static int bucket_max_variable_dim(int max_tokens_per_expert, int total_tokens) {
    // Degenerate inputs: fall back to the always-safe full size.
    if (max_tokens_per_expert <= 0 || total_tokens <= 0)
        return total_tokens;

    // Short sequences: host-side overhead dominates, skip bucketing
    if (total_tokens <= 128)
        return total_tokens;

    int bucket_size;
    if (total_tokens <= 1024) {
        // 4 buckets, first 3 have 32-aligned width, last = total_tokens
        bucket_size = (((total_tokens + 3) / 4) + 31) / 32 * 32;
    } else if (total_tokens <= 8192) {
        // 8 buckets, first 7 have 32-aligned width, last = total_tokens
        bucket_size = (((total_tokens + 7) / 8) + 31) / 32 * 32;
    } else {
        // Fixed 1024-wide buckets for very long sequences
        bucket_size = 1024;
    }

    // Snap up to the bucket ceiling so the result is a true upper bound on
    // max_tokens_per_expert, then clamp to total_tokens for the last bucket.
    // (Returning bucket_size alone would under-estimate whenever one expert
    // receives more than bucket_size tokens, breaking dispatch safety.)
    int bucketed = ((max_tokens_per_expert + bucket_size - 1) / bucket_size) * bucket_size;
    return std::min(bucketed, total_tokens);
}
| 2002 | + |
| 2003 | + // Build (and cache) three grouped dnnl::matmul primitives for gate/up/down. |
| 2004 | + // Cache key is (total_tokens, bucketed_max_variable_dim). The adaptive |
| 2005 | + // bucketing limits distinct primitives to 4-10 per total_tokens value. |
| 2006 | + grouped_onednn_kernel& get_grouped_kernel(int total_tokens, int max_tokens_per_expert, typed_primitive_inst<moe_3gemm_fused_compressed>& instance) { |
| 2007 | + int max_variable_dim = bucket_max_variable_dim(max_tokens_per_expert, total_tokens); |
| 2008 | + auto key = std::make_pair(total_tokens, max_variable_dim); |
| 2009 | + if (_grouped_kernels.has(key)) { |
| 2010 | + return *_grouped_kernels.get(key); |
1972 | 2011 | } |
1973 | 2012 |
|
1974 | 2013 | auto cur_moe = instance.get_typed_desc<moe_3gemm_fused_compressed>(); |
@@ -2012,8 +2051,11 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
2012 | 2051 | } |
2013 | 2052 |
|
2014 | 2053 | // Grouped src/dst: tokens are grouped by expert along axis-0 |
2015 | | - auto src_md = dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, K}, a_dt, 0, num_experts); |
2016 | | - auto dst_md = dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, N}, a_dt, 0, num_experts); |
| 2054 | + // max_variable_dim provides a static per-group upper bound for dispatch optimization |
| 2055 | + auto src_md = |
| 2056 | + dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, K}, a_dt, 0, num_experts, dnnl::memory::data_type::s32, max_variable_dim); |
| 2057 | + auto dst_md = |
| 2058 | + dnnl::memory::desc::grouped(dnnl::memory::dims{total_tokens, N}, a_dt, 0, num_experts, dnnl::memory::data_type::s32, max_variable_dim); |
2017 | 2059 | // Weight: logical [E, K, N], physical layout acb -> stored as [E, N, K] |
2018 | 2060 | auto w_md = dnnl::memory::desc(dnnl::memory::dims{num_experts, K, N}, w_dt, dnnl::memory::format_tag::acb); |
2019 | 2061 |
|
@@ -2051,8 +2093,8 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
2051 | 2093 | if (has_zp) |
2052 | 2094 | gk->down_zp_md = make_quant_md(num_experts, K_d, _down_group_size, N_d, dw_dt); |
2053 | 2095 |
|
2054 | | - _grouped_kernels.add(total_tokens, gk); |
2055 | | - return *_grouped_kernels.get(total_tokens); |
| 2096 | + _grouped_kernels.add(key, gk); |
| 2097 | + return *_grouped_kernels.get(key); |
2056 | 2098 | } |
2057 | 2099 |
|
2058 | 2100 | // inputs 0 is hidden_states, inputs 1 is router_logits[num_tokens, NUM_EXPERTS=128] |
@@ -2242,8 +2284,15 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
2242 | 2284 | } |
2243 | 2285 | int total_gathered_tokens = static_cast<int>(token_num) * max_topk; |
2244 | 2286 |
|
| 2287 | + // Compute actual max tokens assigned to any single expert. |
| 2288 | + int max_tokens_per_expert = 0; |
| 2289 | + if (num_actually_used_experts > 0) { |
| 2290 | + max_tokens_per_expert = *std::max_element(tokens_lens_per_expert_cpu.begin(), tokens_lens_per_expert_cpu.begin() + num_actually_used_experts); |
| 2291 | + } |
| 2292 | + |
2245 | 2293 | GPU_DEBUG_TRACE_DETAIL << "\nexec_prefill_grouped_gemm: token_num=" << token_num << ", total_gathered_tokens=" << total_gathered_tokens |
2246 | | - << ", num_actually_used_experts=" << num_actually_used_experts << std::endl; |
| 2294 | + << ", max_tokens_per_expert=" << max_tokens_per_expert << ", num_actually_used_experts=" << num_actually_used_experts |
| 2295 | + << std::endl; |
2247 | 2296 |
|
2248 | 2297 | // Upload scratch metadata for the scatter_reduce and gather kernels |
2249 | 2298 | intermediates_memories[MOE_INTERNAL_BUFFER_TOKEN_IDX_PER_EXPERT] |
@@ -2290,7 +2339,7 @@ class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { |
2290 | 2339 | // ---------------------------------------------------------------- |
2291 | 2340 | // Steps 3-5: OneDNN grouped GEMM – gate, up, SiLU, down |
2292 | 2341 | // ---------------------------------------------------------------- |
2293 | | - auto& gk = get_grouped_kernel(total_gathered_tokens, instance); |
| 2342 | + auto& gk = get_grouped_kernel(static_cast<int>(token_num), max_tokens_per_expert, instance); |
2294 | 2343 | auto* offsets_ptr = intermediates_memories[MOE_INTERNAL_BUFFER_GROUPED_OFFSETS]->buffer_ptr(); |
2295 | 2344 |
|
2296 | 2345 | // Helper: wrap a flat USM buffer as an OneDNN grouped memory (data + expert row-offsets) |
|
0 commit comments