Skip to content

Commit 78170f6

Browse files
committed
Merge remote-tracking branch 'upstream/main' into u8s8
2 parents aa61c38 + 1087d59 commit 78170f6

File tree

138 files changed

+7029
-11405
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

138 files changed

+7029
-11405
lines changed

.github/workflows/dashboard_perf_test.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ jobs:
4545
# llama3 - compile baseline
4646
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
4747
48-
# llama3 - autoquant
49-
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
50-
5148
# skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407
5249
# # SAM
5350
# ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from torchao.quantization import (
1717
Float8DynamicActivationFloat8WeightConfig,
1818
Float8WeightOnlyConfig,
19-
GemliteUIntXWeightOnlyConfig,
2019
Int8DynamicActivationInt8WeightConfig,
2120
Int8WeightOnlyConfig,
2221
MappingType,
@@ -182,11 +181,7 @@ def string_to_config(
182181
if "int8wo" in quantization:
183182
return Int8WeightOnlyConfig()
184183
if "int8dq" in quantization:
185-
if sparsity is not None and ("semi" in sparsity or "2:4" in sparsity):
186-
from torchao.dtypes import SemiSparseLayout
187-
188-
return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
189-
elif "int8dq_prefill_wo_decode" in quantization:
184+
if "int8dq_prefill_wo_decode" in quantization:
190185
return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True)
191186
else:
192187
return Int8DynamicActivationInt8WeightConfig()
@@ -225,23 +220,6 @@ def string_to_config(
225220
else:
226221
granularity = PerTensor()
227222
return Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
228-
if "gemlitewo" in quantization:
229-
params = quantization.split("-")
230-
bit_width = int(params[1]) if len(params) > 1 else 4
231-
group_size = (
232-
int(params[2])
233-
if len(params) > 2 and bit_width == 4
234-
else None
235-
if bit_width == 8
236-
else 64
237-
)
238-
assert group_size in [
239-
32,
240-
64,
241-
128,
242-
256,
243-
], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
244-
return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width)
245223
return None
246224

247225

benchmarks/mx_formats/cast_bench.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def run(
109109
print(f"triton version: {triton.__version__}")
110110
print(f"mode: {mode}")
111111
assert mode in (
112+
"memcpy",
112113
"dim0",
113114
"dim1",
114115
"dim0_dim1",
@@ -125,11 +126,31 @@ def run(
125126
"dim1_mxfp8_triton_rceil",
126127
"dim1_mxfp8_cuda_floor",
127128
"dim1_mxfp8_cuda_rceil",
129+
"dim0_mxfp8_cutedsl_2d_floor",
130+
"dim0_mxfp8_cutedsl_2d_rceil",
128131
)
129132

130133
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
131134

132-
if mode == "dim0":
135+
if mode == "memcpy":
136+
# Baseline memcpy benchmark to establish max achievable bandwidth
137+
y = torch.randn_like(x)
138+
139+
# Warmup
140+
for _ in range(2):
141+
y.copy_(x)
142+
143+
time_us = benchmark_cuda_function_in_microseconds(
144+
lambda src, dst: dst.copy_(src),
145+
x,
146+
y,
147+
)
148+
149+
# bytes_read + bytes_written
150+
bytes_rw = 2 * x.numel() * bytes_per_el_bf16
151+
bps = bytes_rw / (time_us / 1e6)
152+
153+
elif mode == "dim0":
133154
scale_dim0_reference_c = torch.compile(scale_dim0_reference)
134155
y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)
135156

@@ -452,6 +473,54 @@ def run(
452473
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
453474
bps = (bytes_r + bytes_w) / (time_us / 1e6)
454475

476+
elif mode == "dim0_mxfp8_cutedsl_2d_floor":
477+
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
478+
479+
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
480+
x, block_size=BLOCK_SIZE, scaling_mode="floor"
481+
)
482+
483+
for _ in range(2):
484+
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="floor")
485+
486+
time_us = benchmark_cuda_function_in_microseconds(
487+
lambda x: mxfp8_quantize_cuda_2d(
488+
x, block_size=BLOCK_SIZE, scaling_mode="floor"
489+
),
490+
x,
491+
)
492+
493+
assert y_d0.dtype == torch.float8_e4m3fn
494+
assert s_d0.dtype == torch.float8_e8m0fnu
495+
496+
bytes_r = x.numel() * bytes_per_el_bf16
497+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
498+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
499+
500+
elif mode == "dim0_mxfp8_cutedsl_2d_rceil":
501+
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
502+
503+
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
504+
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
505+
)
506+
507+
for _ in range(2):
508+
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="rceil")
509+
510+
time_us = benchmark_cuda_function_in_microseconds(
511+
lambda x: mxfp8_quantize_cuda_2d(
512+
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
513+
),
514+
x,
515+
)
516+
517+
assert y_d0.dtype == torch.float8_e4m3fn
518+
assert s_d0.dtype == torch.float8_e8m0fnu
519+
520+
bytes_r = x.numel() * bytes_per_el_bf16
521+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
522+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
523+
455524
else:
456525
raise AssertionError(f"unknown mode {mode}")
457526

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Blockwise FP8 Training Benchmarks
2+
3+
This directory contains benchmarking scripts for the blockwise FP8 quantization
4+
and GEMM paths under `torchao.prototype.blockwise_fp8_training.kernels`.
5+
6+
## Quantized Kernel Bandwidth Benchmark
7+
8+
The kernel-path bandwidth utility is:
9+
10+
```bash
11+
python -m benchmarks.prototype.blockwise_fp8_training.benchmark_quant_kernel_bandwidth
12+
```
13+
14+
To additionally validate Triton outputs against the Torch reference
15+
implementations:
16+
17+
```bash
18+
python -m benchmarks.prototype.blockwise_fp8_training.benchmark_quant_kernel_bandwidth --check-correctness
19+
```
20+
21+
What it reports:
22+
23+
- `kernel_us`: measured runtime of the public quantization wrapper call, in microseconds
24+
- `effective_logical_io_gbps`: logical tensor IO bytes divided by measured time, in GB/s
25+
- `logical_io_vs_achievable_%`: `effective_logical_io_gbps / achievable_bandwidth_gbps`, expressed as a percentage
26+
27+
Notes:
28+
29+
- The benchmark times the public wrapper functions in
30+
`torchao.prototype.blockwise_fp8_training.kernels`.
31+
- `--check-correctness` runs the matching Torch reference path once per valid
32+
kernel and shape before reporting results. This adds overhead and is intended
33+
for validation, not headline timing runs.
34+
- The bandwidth number uses the expected tensor IO footprint, not hardware DRAM
35+
counters.
36+
- Peak bandwidth defaults to CUDA device properties. `--use-roofline-utils`
37+
switches to the static `roofline_utils` table.
38+
39+
### Methodology
40+
41+
- It times the public wrapper call, matching the style of the other benchmark
42+
scripts in this directory.
43+
- It uses CUDA event timing and reports the median runtime, via
44+
`benchmark_cuda_function_in_microseconds(...)` from
45+
[benchmarks/utils.py](/benchmarks/utils.py#L101).
46+
- It validates unsupported shapes up front and skips them instead of silently
47+
measuring invalid configurations.
48+
49+
## Current H100 Results
50+
51+
Captured on 2026-03-20 with:
52+
53+
```bash
54+
python -m benchmarks.prototype.blockwise_fp8_training.benchmark_quant_kernel_bandwidth
55+
```
56+
57+
Environment:
58+
59+
- GPU: `NVIDIA H100 80GB HBM3`
60+
- Peak bandwidth reference: `3352.3 GB/s`
61+
- Peak bandwidth source: `cuda_device_properties`
62+
- Achievable bandwidth reference: `3084.1 GB/s`
63+
- Achievable bandwidth uses `92.0%` of peak bandwidth
64+
- Achievable bandwidth source: `roofline_utils_pct_achievable_mem_bw`
65+
66+
### Per-shape Results
67+
Tested with shapes 32768x4096 and 131072x4096 to reflect real-world training:
68+
69+
| kernel | shape | kernel_us | effective_logical_io_gbps | logical_io_vs_achievable_% |
70+
|---|---|---:|---:|---:|
71+
| act_quant_transposed_lhs | 32768x4096 | 154.46 | 2633.9 | 85.4 |
72+
| weight_quant_transposed_rhs | 32768x4096 | 150.53 | 2675.2 | 86.7 |
73+
| act_quant_lhs | 32768x4096 | 150.86 | 2696.8 | 87.4 |
74+
| act_quant_rhs | 32768x4096 | 148.70 | 2736.0 | 88.7 |
75+
| weight_quant_rhs | 32768x4096 | 144.99 | 2777.3 | 90.1 |
76+
| weight_quant_transposed_rhs | 131072x4096 | 581.89 | 2768.1 | 89.8 |
77+
| act_quant_lhs | 131072x4096 | 586.98 | 2772.5 | 89.9 |
78+
| act_quant_transposed_lhs | 131072x4096 | 581.47 | 2798.7 | 90.7 |
79+
| act_quant_rhs | 131072x4096 | 562.56 | 2892.8 | 93.8 |
80+
| weight_quant_rhs | 131072x4096 | 555.30 | 2900.7 | 94.1 |

0 commit comments

Comments
 (0)