|
14 | 14 | import os |
15 | 15 | import re |
16 | 16 |
|
17 | | -file_dir = "./gpu_ops/w4afp8_gemm/" |
| 17 | +script_dir = os.path.dirname(os.path.abspath(__file__)) |
| 18 | +file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep |
18 | 19 |
|
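The hunk above swaps a CWD-relative path for one anchored at the script's own location, so the generator writes to the same `gpu_ops/w4afp8_gemm` directory no matter where it is invoked from. A minimal sketch of the pattern (the directory names come from the diff; the `pathlib` variant at the end is a common alternative, not what this patch uses):

```python
import os

# Resolve the output directory relative to this script, not the current
# working directory. os.path.abspath(__file__) works even when the script
# is run as e.g. `python tools/gen.py` from an unrelated directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep

# Equivalent pathlib form (an alternative, not used by the patch):
from pathlib import Path
file_dir_p = Path(__file__).resolve().parent.parent / "gpu_ops" / "w4afp8_gemm"
```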
19 | 20 | gemm_template_head = """ |
20 | 21 | #pragma once |
|
85 | 86 | """ |
86 | 87 |
|
87 | 88 | # [M, K, number of experts, token padding size, weight K group size] |
88 | | -gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128]] |
| 89 | +gemm_case = [ |
| 90 | + [256, 256, 2, 0, 128], |
| 91 | + [512, 256, 2, 0, 128], |
| 92 | + [7168, 7168, 6, 8192, 128], # num_max_dispatch_tokens_per_rank=128 |
| 93 | + [7168, 3584, 6, 8192, 128], # num_max_dispatch_tokens_per_rank=128 |
| 94 | + [7168, 7168, 6, 10240, 128], # num_max_dispatch_tokens_per_rank=160 |
| 95 | + [7168, 3584, 6, 10240, 128], # num_max_dispatch_tokens_per_rank=160 |
| 96 | + [7168, 7168, 6, 12288, 128], # num_max_dispatch_tokens_per_rank=192 |
| 97 | + [7168, 3584, 6, 12288, 128], # num_max_dispatch_tokens_per_rank=192 |
| 98 | + [7168, 7168, 6, 16384, 128], # num_max_dispatch_tokens_per_rank=256 |
| 99 | + [7168, 3584, 6, 16384, 128], # num_max_dispatch_tokens_per_rank=256 |
| 100 | + [7168, 7168, 6, 20480, 128], # num_max_dispatch_tokens_per_rank=320 |
| 101 | + [7168, 3584, 6, 20480, 128], # num_max_dispatch_tokens_per_rank=320 |
| 102 | + [7168, 7168, 7, 8192, 128], # num_max_dispatch_tokens_per_rank=128 |
| 103 | + [7168, 3584, 7, 8192, 128], # num_max_dispatch_tokens_per_rank=128 |
| 104 | + [7168, 7168, 7, 10240, 128], # num_max_dispatch_tokens_per_rank=160 |
| 105 | + [7168, 3584, 7, 10240, 128], # num_max_dispatch_tokens_per_rank=160 |
| 106 | + [7168, 7168, 7, 12288, 128], # num_max_dispatch_tokens_per_rank=192 |
| 107 | + [7168, 3584, 7, 12288, 128], # num_max_dispatch_tokens_per_rank=192 |
| 108 | + [7168, 7168, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256 |
| 109 | + [7168, 3584, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256 |
| 110 | + [7168, 7168, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320 |
| 111 | + [7168, 3584, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320 |
| 112 | +] |
89 | 113 |
|
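The comments in the new cases tie each token padding size to `num_max_dispatch_tokens_per_rank` by a constant factor of 64, presumably the number of ranks the dispatch fans out across (an assumption; the EP world size is not stated in this diff). A quick check of the arithmetic:

```python
# Assumption: padding = num_max_dispatch_tokens_per_rank * NUM_RANKS,
# with NUM_RANKS = 64 inferred from the ratios in the table above.
NUM_RANKS = 64  # hypothetical; not stated in the patch
for tokens_per_rank in (128, 160, 192, 256, 320):
    padding = tokens_per_rank * NUM_RANKS
    print(tokens_per_rank, "->", padding)  # 128 -> 8192, ..., 320 -> 20480
```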
90 | 114 | dtype = ["BF16"] |
91 | 115 |
|
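For context, a generator like this typically expands a C++ template string once per `(gemm_case, dtype)` pair and writes the instantiations into `file_dir`. A minimal sketch under that assumption; `gemm_template_body`, its placeholder names, and the output-file naming scheme are hypothetical, not taken from this patch:

```python
import os

# Hypothetical template body; the real one is elided from this diff.
gemm_template_body = """
template void w4afp8_gemm<{dtype}, {M}, {K}, {experts}, {padding}, {group}>();
"""

os.makedirs(file_dir, exist_ok=True)
for M, K, experts, padding, group in gemm_case:
    for dt in dtype:
        src = gemm_template_head + gemm_template_body.format(
            dtype=dt, M=M, K=K, experts=experts, padding=padding, group=group
        )
        # Hypothetical naming scheme: one .cu file per instantiation.
        fname = f"w4afp8_gemm_M{M}_K{K}_E{experts}_P{padding}_{dt}.cu"
        with open(os.path.join(file_dir, fname), "w") as f:
            f.write(src)
```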