Skip to content

Commit 5210f8d

Browse files
[Cherry-Pick][CI] Increase the shape of w4afp8 gemm(#5957) (#5948)
* 增加w4afp8 gemm shape (add w4afp8 gemm shapes) * 增加w4afp8 shape (add w4afp8 shapes) * code style * code style
1 parent c2ad0a9 commit 5210f8d

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py

Lines changed: 26 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -14,7 +14,8 @@
1414
import os
1515
import re
1616

17-
file_dir = "./gpu_ops/w4afp8_gemm/"
17+
script_dir = os.path.dirname(os.path.abspath(__file__))
18+
file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep
1819

1920
gemm_template_head = """
2021
#pragma once
@@ -85,7 +86,30 @@
8586
"""
8687

8788
# [M, K, Number of experts, token Padding Size, weight K group size]
88-
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]]
89+
gemm_case = [
90+
[256, 256, 2, 0, 128],
91+
[512, 256, 2, 0, 128],
92+
[7168, 7168, 6, 8192, 128], # num_max_dispatch_tokens_per_rank=128
93+
[7168, 3584, 6, 8192, 128], # num_max_dispatch_tokens_per_rank=128
94+
[7168, 7168, 6, 10240, 128], # num_max_dispatch_tokens_per_rank=160
95+
[7168, 3584, 6, 10240, 128], # num_max_dispatch_tokens_per_rank=160
96+
[7168, 7168, 6, 12288, 128], # num_max_dispatch_tokens_per_rank=192
97+
[7168, 3584, 6, 12288, 128], # num_max_dispatch_tokens_per_rank=192
98+
[7168, 7168, 6, 16384, 128], # num_max_dispatch_tokens_per_rank=256
99+
[7168, 3584, 6, 16384, 128], # num_max_dispatch_tokens_per_rank=256
100+
[7168, 7168, 6, 20480, 128], # num_max_dispatch_tokens_per_rank=320
101+
[7168, 3584, 6, 20480, 128], # num_max_dispatch_tokens_per_rank=320
102+
[7168, 7168, 7, 8192, 128], # num_max_dispatch_tokens_per_rank=128
103+
[7168, 3584, 7, 8192, 128], # num_max_dispatch_tokens_per_rank=128
104+
[7168, 7168, 7, 10240, 128], # num_max_dispatch_tokens_per_rank=160
105+
[7168, 3584, 7, 10240, 128], # num_max_dispatch_tokens_per_rank=160
106+
[7168, 7168, 7, 12288, 128], # num_max_dispatch_tokens_per_rank=192
107+
[7168, 3584, 7, 12288, 128], # num_max_dispatch_tokens_per_rank=192
108+
[7168, 7168, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256
109+
[7168, 3584, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256
110+
[7168, 7168, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
111+
[7168, 3584, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
112+
]
89113

90114
dtype = ["BF16"]
91115

0 commit comments

Comments (0)