Skip to content

Commit 02105d4

Browse files
[mxfp8 training] add cutedsl kernel for mxfp8 quantization along dim0 (#4156)
* [mxfp8 moe training] add cutedsl to quantize 2d tensor along dim0 * vectorized stores for scales * 5.1 tb/s * iter along k; 5.8-6.4 tb/s * add bench script * refactor to use shared cute utils * add docstrings, update variable names for clarity * update var names
1 parent d17c61b commit 02105d4

File tree

10 files changed

+1720
-200
lines changed

10 files changed

+1720
-200
lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def run(
109109
print(f"triton version: {triton.__version__}")
110110
print(f"mode: {mode}")
111111
assert mode in (
112+
"memcpy",
112113
"dim0",
113114
"dim1",
114115
"dim0_dim1",
@@ -125,11 +126,31 @@ def run(
125126
"dim1_mxfp8_triton_rceil",
126127
"dim1_mxfp8_cuda_floor",
127128
"dim1_mxfp8_cuda_rceil",
129+
"dim0_mxfp8_cutedsl_2d_floor",
130+
"dim0_mxfp8_cutedsl_2d_rceil",
128131
)
129132

130133
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
131134

132-
if mode == "dim0":
135+
if mode == "memcpy":
136+
# Baseline memcpy benchmark to establish max achievable bandwidth
137+
y = torch.randn_like(x)
138+
139+
# Warmup
140+
for _ in range(2):
141+
y.copy_(x)
142+
143+
time_us = benchmark_cuda_function_in_microseconds(
144+
lambda src, dst: dst.copy_(src),
145+
x,
146+
y,
147+
)
148+
149+
# bytes_read + bytes_written
150+
bytes_rw = 2 * x.numel() * bytes_per_el_bf16
151+
bps = bytes_rw / (time_us / 1e6)
152+
153+
elif mode == "dim0":
133154
scale_dim0_reference_c = torch.compile(scale_dim0_reference)
134155
y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)
135156

@@ -452,6 +473,54 @@ def run(
452473
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
453474
bps = (bytes_r + bytes_w) / (time_us / 1e6)
454475

476+
elif mode == "dim0_mxfp8_cutedsl_2d_floor":
477+
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
478+
479+
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
480+
x, block_size=BLOCK_SIZE, scaling_mode="floor"
481+
)
482+
483+
for _ in range(2):
484+
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="floor")
485+
486+
time_us = benchmark_cuda_function_in_microseconds(
487+
lambda x: mxfp8_quantize_cuda_2d(
488+
x, block_size=BLOCK_SIZE, scaling_mode="floor"
489+
),
490+
x,
491+
)
492+
493+
assert y_d0.dtype == torch.float8_e4m3fn
494+
assert s_d0.dtype == torch.float8_e8m0fnu
495+
496+
bytes_r = x.numel() * bytes_per_el_bf16
497+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
498+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
499+
500+
elif mode == "dim0_mxfp8_cutedsl_2d_rceil":
501+
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
502+
503+
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
504+
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
505+
)
506+
507+
for _ in range(2):
508+
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="rceil")
509+
510+
time_us = benchmark_cuda_function_in_microseconds(
511+
lambda x: mxfp8_quantize_cuda_2d(
512+
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
513+
),
514+
x,
515+
)
516+
517+
assert y_d0.dtype == torch.float8_e4m3fn
518+
assert s_d0.dtype == torch.float8_e8m0fnu
519+
520+
bytes_r = x.numel() * bytes_per_el_bf16
521+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
522+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
523+
455524
else:
456525
raise AssertionError(f"unknown mode {mode}")
457526

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
16+
from benchmarks.utils import benchmark_cuda_function_in_microseconds
17+
from torchao.prototype.moe_training.kernels.mxfp8 import (
18+
mx_block_rearrange_2d_M_groups_cuda,
19+
)
20+
from torchao.prototype.moe_training.kernels.mxfp8.cutedsl_quantize_2d import (
21+
mxfp8_quantize_cutedsl_2d,
22+
)
23+
from torchao.prototype.moe_training.utils import generate_jagged_offs
24+
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
25+
26+
device = torch.device("cuda")
27+
28+
# Needed since changing args to function causes recompiles
29+
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
@dataclass(frozen=True)
33+
class ExperimentConfig:
34+
input_shape: tuple[int, int]
35+
scaling_mode: str
36+
num_groups: int
37+
38+
39+
@dataclass(frozen=True)
40+
class ExperimentResult:
41+
# time
42+
cutedsl_blocked_us: float
43+
triton_plus_rearrange_us: float
44+
# mem bw
45+
cutedsl_blocked_gbps: float
46+
triton_plus_rearrange_gbps: float
47+
48+
49+
@dataclass(frozen=True)
50+
class Experiment:
51+
config: ExperimentConfig
52+
result: ExperimentResult
53+
54+
55+
def get_configs() -> List[ExperimentConfig]:
56+
input_shapes = [
57+
# DeepSeekV3 671b shapes
58+
(8192, 2048),
59+
(8192, 7168),
60+
(32768, 2048),
61+
(32768, 7168),
62+
(131072, 2048),
63+
(131072, 7168),
64+
]
65+
scaling_modes = ["floor", "rceil"]
66+
num_groups_list = [8]
67+
configs = []
68+
for shape, scaling_mode, num_groups in itertools.product(
69+
input_shapes, scaling_modes, num_groups_list
70+
):
71+
configs.append(
72+
ExperimentConfig(
73+
input_shape=shape,
74+
scaling_mode=scaling_mode,
75+
num_groups=num_groups,
76+
)
77+
)
78+
return configs
79+
80+
81+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
82+
block_size = 32
83+
input_shape = config.input_shape
84+
scaling_mode = config.scaling_mode
85+
num_groups = config.num_groups
86+
87+
input_tensor = torch.randn(
88+
*input_shape,
89+
dtype=torch.bfloat16,
90+
device=device,
91+
)
92+
93+
M, K = input_shape
94+
95+
# Generate jagged offsets with multiples of 128
96+
# TODO: we use multiple of 128 here to avoid per-group padding requirement in blocked scales layout, which cutedsl doesn't support yet.
97+
group_end_offsets = generate_jagged_offs(
98+
num_groups, M, multiple_of=128, device=device
99+
)
100+
101+
# Benchmark 1: CuTeDSL kernel with blocked scale output
102+
data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d(
103+
input_tensor,
104+
block_size=block_size,
105+
scaling_mode=scaling_mode,
106+
blocked_scale_output=True,
107+
)
108+
cutedsl_blocked_time_us = benchmark_cuda_function_in_microseconds(
109+
mxfp8_quantize_cutedsl_2d,
110+
input_tensor,
111+
block_size=block_size,
112+
scaling_mode=scaling_mode,
113+
blocked_scale_output=True,
114+
)
115+
116+
# Benchmark 2: Triton quantization + CUDA scale rearrangement
117+
def triton_plus_rearrange(x, group_offs):
118+
# Quantize along dim0 (rowwise)
119+
data, scales = triton_to_mxfp8_dim0(
120+
x,
121+
inner_block_size=block_size,
122+
scaling_mode=scaling_mode,
123+
)
124+
# Convert scales to blocked layout
125+
scales_blocked = mx_block_rearrange_2d_M_groups_cuda(
126+
scales.view(torch.uint8), group_offs
127+
)
128+
return data, scales_blocked
129+
130+
data_triton, scales_triton = triton_plus_rearrange(input_tensor, group_end_offsets)
131+
triton_plus_rearrange_time_us = benchmark_cuda_function_in_microseconds(
132+
triton_plus_rearrange,
133+
input_tensor,
134+
group_end_offsets,
135+
)
136+
137+
# Memory bandwidth calculations
138+
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
139+
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
140+
bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
141+
142+
read_bytes = input_tensor.numel() * bytes_per_input_el
143+
write_bytes = (
144+
data_cutedsl.numel() * bytes_per_output_el
145+
+ scales_cutedsl.numel() * bytes_per_scale_el
146+
)
147+
148+
cutedsl_blocked_gbps = ((read_bytes + write_bytes) / 1e9) / (
149+
cutedsl_blocked_time_us / 1e6
150+
)
151+
triton_plus_rearrange_gbps = ((read_bytes + write_bytes) / 1e9) / (
152+
triton_plus_rearrange_time_us / 1e6
153+
)
154+
155+
return ExperimentResult(
156+
cutedsl_blocked_us=cutedsl_blocked_time_us,
157+
triton_plus_rearrange_us=triton_plus_rearrange_time_us,
158+
cutedsl_blocked_gbps=cutedsl_blocked_gbps,
159+
triton_plus_rearrange_gbps=triton_plus_rearrange_gbps,
160+
)
161+
162+
163+
def print_results(experiments: List[Experiment]):
164+
headers = [
165+
"input_shape",
166+
"scaling_mode",
167+
"num_groups",
168+
"cutedsl_blocked_us",
169+
"triton+rearrange_us",
170+
"speedup",
171+
"cutedsl_gbps",
172+
"triton+rearrange_gbps",
173+
]
174+
rows = []
175+
for experiment in experiments:
176+
speedup = (
177+
experiment.result.triton_plus_rearrange_us
178+
/ experiment.result.cutedsl_blocked_us
179+
)
180+
rows.append(
181+
[
182+
str(experiment.config.input_shape),
183+
experiment.config.scaling_mode,
184+
experiment.config.num_groups,
185+
f"{experiment.result.cutedsl_blocked_us:.2f}",
186+
f"{experiment.result.triton_plus_rearrange_us:.2f}",
187+
f"{speedup:.2f}x",
188+
f"{experiment.result.cutedsl_blocked_gbps:.1f}",
189+
f"{experiment.result.triton_plus_rearrange_gbps:.1f}",
190+
]
191+
)
192+
print(tabulate(rows, headers=headers))
193+
194+
195+
def main():
196+
torch.random.manual_seed(123)
197+
configs = get_configs()
198+
results = []
199+
for config in tqdm(configs):
200+
result = run_experiment(config)
201+
results.append(Experiment(config=config, result=result))
202+
203+
# Use Tabulate to print results
204+
print_results(results)
205+
206+
207+
if __name__ == "__main__":
208+
main()

test/prototype/moe_training/test_kernels.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _is_sm_10x() -> bool:
3838
fused_pad_token_groups_cuda,
3939
fused_unpad_token_groups_cuda,
4040
mx_block_rearrange_2d_M_groups_cuda,
41+
mxfp8_quantize_cuda_2d,
4142
mxfp8_quantize_cuda_3d,
4243
torch_pad_token_groups,
4344
torch_to_blocked_2d_K_groups,
@@ -436,6 +437,60 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
436437
assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
437438

438439

440+
@pytest.mark.skipif(
441+
not _is_sm_10x(),
442+
reason="MXFP8 requires CUDA SM 10.x",
443+
)
444+
@pytest.mark.skipif(
445+
not _mxfp8_cutedsl_kernels_available,
446+
reason="MXFP8 cutedsl kernels not available",
447+
)
448+
@pytest.mark.parametrize("M", (32, 160, 8192))
449+
@pytest.mark.parametrize("K", (32, 96, 1536, 5120, 7168, 8192))
450+
@pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
451+
@pytest.mark.parametrize(
452+
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
453+
)
454+
def test_cuda_mx_dim0_2d_numerics(M, K, input_dtype, scaling_mode):
455+
scaling_mode_str = scaling_mode.value.lower()
456+
block_size = 32
457+
458+
# Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
459+
x = (
460+
torch.arange(0, M * K, dtype=input_dtype, device="cuda")
461+
.reshape(M, K)
462+
.contiguous()
463+
)
464+
465+
# Reference implementation
466+
s_d0_ref, y_d0_ref = to_mx(
467+
x,
468+
elem_dtype=torch.float8_e4m3fn,
469+
block_size=block_size,
470+
scaling_mode=scaling_mode,
471+
)
472+
473+
# CuTeDSL kernel implementation
474+
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
475+
x,
476+
block_size=block_size,
477+
scaling_mode=scaling_mode_str,
478+
)
479+
480+
# Convert blocked scales back to reference format
481+
s_d0 = from_blocked(s_d0, M, K // block_size).to(s_d0_ref.dtype)
482+
483+
# Check scales
484+
torch.testing.assert_close(s_d0, s_d0_ref, rtol=0, atol=0)
485+
486+
# Check quantized values
487+
torch.testing.assert_close(y_d0, y_d0_ref, rtol=0, atol=0)
488+
489+
# Verify row-major layout
490+
assert y_d0.stride() == (K, 1), "quantized tensor should be row-major"
491+
assert y_d0.stride() == y_d0_ref.stride(), "quantized tensor strides do not match"
492+
493+
439494
@pytest.mark.skipif(
440495
not _mxfp8_cuda_kernels_available,
441496
reason="CUDA kernel requires sm_100 and CUDA 12.8+",

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
fused_pad_token_groups_cuda, # noqa: F401
44
fused_unpad_token_groups_cuda, # noqa: F401
55
mx_block_rearrange_2d_M_groups_cuda, # noqa: F401
6+
mxfp8_quantize_cuda_2d, # noqa: F401
67
mxfp8_quantize_cuda_3d, # noqa: F401
78
torch_pad_token_groups, # noqa: F401
89
torch_to_blocked_2d_K_groups, # noqa: F401

0 commit comments

Comments
 (0)