Skip to content

Commit 96a9cdf

Browse files
[mxfp8 moe training] remove unused block_size arg (#4177)
1 parent efbcb0e commit 96a9cdf

File tree

6 files changed

+13
-29
lines changed

6 files changed

+13
-29
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_ep_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def mxfp8_pipeline(
226226
mx_permuted,
227227
expert_weights_t,
228228
offs=mx_group_offsets,
229-
block_size=block_size,
230229
wgrad_with_hp=True,
231230
)
232231

benchmarks/prototype/moe_training/mxfp8/roofline_unified.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def wrapper():
436436
return time_ms
437437

438438

439-
def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels, block_size=32):
439+
def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels):
440440
"""Benchmark _to_mxfp8_then_scaled_grouped_mm forward + backward"""
441441
x_clone = x.clone().requires_grad_(True)
442442
w_t_clone = w_t.clone().requires_grad_(True)
@@ -447,7 +447,6 @@ def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels, block_size=32):
447447
A = x_clone
448448
B_t = w_t_clone
449449
offs_arg = offs
450-
block_size_arg = block_size
451450
out_dtype = torch.bfloat16
452451
kernel_preference = KernelPreference.AUTO
453452
wgrad_with_hp = False
@@ -458,7 +457,6 @@ def wrapper():
458457
A,
459458
B_t,
460459
offs_arg,
461-
block_size_arg,
462460
out_dtype,
463461
kernel_preference,
464462
wgrad_with_hp,

test/prototype/moe_training/ep/test_compile.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def standard_pipeline(
8686
permuted,
8787
expert_weights_t,
8888
offs=offsets,
89-
block_size=block_size,
9089
wgrad_with_hp=True,
9190
)
9291

@@ -154,7 +153,6 @@ def mxfp8_pipeline(
154153
mx_permuted,
155154
expert_weights_t,
156155
offs=mx_group_offsets,
157-
block_size=block_size,
158156
wgrad_with_hp=True,
159157
)
160158

test/prototype/moe_training/ep/test_integration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ def test_full_pipeline(self):
230230
mx_permuted,
231231
expert_weights.transpose(-2, -1),
232232
offs=mx_group_offsets,
233-
block_size=block_size,
234233
# wgrad_with_hp must be true if inputs are pre-quantized (MXTensor)
235234
wgrad_with_hp=True,
236235
)

test/prototype/moe_training/test_mxfp8_grouped_mm.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
165165
"torch native dynamic per group pad/unpad functions do not work with torch.compile yet: https://github.com/pytorch/pytorch/issues/176770"
166166
)
167167

168-
block_size = 32
169168
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
170169
w = torch.randn(
171170
num_experts,
@@ -194,7 +193,6 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
194193
x,
195194
w_t,
196195
offs=offs,
197-
block_size=block_size,
198196
kernel_preference=kernel_preference,
199197
wgrad_with_hp=wgrad_with_hp,
200198
scale_calculation_mode=scale_mode,
@@ -262,7 +260,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
262260
x_mx,
263261
w_t,
264262
offs=offs,
265-
block_size=block_size,
266263
out_dtype=torch.bfloat16,
267264
kernel_preference=KernelPreference.EMULATED,
268265
wgrad_with_hp=True,
@@ -272,7 +269,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
272269
x_ref,
273270
w_t_ref,
274271
offs=offs,
275-
block_size=block_size,
276272
out_dtype=torch.bfloat16,
277273
kernel_preference=KernelPreference.EMULATED,
278274
wgrad_with_hp=True,
@@ -334,7 +330,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
334330
x_mx,
335331
w_t,
336332
offs=offs,
337-
block_size=block_size,
338333
out_dtype=torch.bfloat16,
339334
kernel_preference=KernelPreference.EMULATED,
340335
wgrad_with_hp=True,
@@ -344,7 +339,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
344339
x,
345340
w_t,
346341
offs=offs,
347-
block_size=block_size,
348342
out_dtype=torch.bfloat16,
349343
kernel_preference=KernelPreference.EMULATED,
350344
wgrad_with_hp=True,
@@ -392,7 +386,6 @@ def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
392386
x_mx,
393387
w_t,
394388
offs=offs,
395-
block_size=block_size,
396389
out_dtype=torch.bfloat16,
397390
kernel_preference=KernelPreference.EMULATED,
398391
wgrad_with_hp=False,

torchao/prototype/moe_training/mxfp8_grouped_mm.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
8383
A: torch.Tensor,
8484
B_t: torch.Tensor,
8585
offs: Optional[torch.Tensor] = None,
86-
block_size: Optional[int] = None,
8786
out_dtype: Optional[torch.dtype] = torch.bfloat16,
8887
kernel_preference: KernelPreference = KernelPreference.AUTO,
8988
wgrad_with_hp: bool = False,
@@ -103,7 +102,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
103102
which must be 3D, which must be shape (G, K, N)
104103
and in "per group column-major memory" layout (i.e., strides of (N*K, 1, N)).
105104
offs (int32 torch.Tensor): The offsets to use to mark the end index of each group along the dim0 of the A tensor.
106-
block_size (int): Block size for MXFP8 quantization. Must be 32 (the only supported value). This parameter exists for backward compatibility but is ignored.
107105
out_dtype (torch.dtype): Output dtype for the result. Defaults to torch.bfloat16.
108106
kernel_preference (KernelPreference): Kernel preference (AUTO uses CUDA/Triton, EMULATED uses to_mx). Defaults to KernelPreference.AUTO.
109107
wgrad_with_hp (bool): Whether to compute weight gradient in high precision. Defaults to False.
@@ -120,7 +118,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
120118
A,
121119
B_t,
122120
offs,
123-
block_size,
124121
out_dtype,
125122
kernel_preference,
126123
wgrad_with_hp,
@@ -144,7 +141,6 @@ def forward(
144141
input_act: torch.Tensor,
145142
weight_t: torch.Tensor,
146143
group_end_offsets: Optional[torch.Tensor] = None,
147-
block_size: int = 32,
148144
out_dtype: Optional[torch.dtype] = torch.bfloat16,
149145
kernel_preference: KernelPreference = KernelPreference.AUTO,
150146
wgrad_with_hp: bool = False,
@@ -158,15 +154,18 @@ def forward(
158154
input_act: Input activations, shape (M, K) - may be MXTensor or high-precision
159155
weight_t: Expert weights transposed, shape (E, K, N) - always high-precision
160156
group_end_offsets: End index of each token group, shape (E,)
161-
block_size: Block size for MXFP8 quantization (must be 32)
162157
out_dtype: Output dtype (bfloat16 or float32)
163158
kernel_preference: Kernel preference (AUTO uses CUDA/Triton, EMULATED uses to_mx)
164159
wgrad_with_hp: Compute weight gradient in high precision
165160
scale_calculation_mode: Mode for scale calculation (RCEIL, FLOOR, etc.)
161+
pad_token_groups_for_grouped_mm: Whether to pad token groups to the next multiple of 32
166162
167163
Returns:
168164
Output tensor, shape (M, N)
169165
"""
166+
# block_size is always 32 for MXFP8
167+
block_size = 32
168+
170169
assert kernel_preference in (
171170
KernelPreference.AUTO,
172171
KernelPreference.EMULATED,
@@ -182,7 +181,6 @@ def forward(
182181
# Input validation
183182
assert input_act.ndim == 2, "input_act must be 2D"
184183
assert weight_t.ndim == 3, "weight_t must be 3D"
185-
assert block_size == 32, "Only block_size=32 is supported"
186184
assert group_end_offsets is not None, (
187185
"group_end_offsets must be provided for 2d-3d grouped mm"
188186
)
@@ -247,7 +245,6 @@ def forward(
247245
padded_group_start_offsets,
248246
padded_group_end_offsets,
249247
)
250-
ctx.block_size = block_size
251248
ctx.out_dtype = out_dtype
252249
ctx.kernel_preference = kernel_preference
253250
ctx.wgrad_with_hp = wgrad_with_hp
@@ -279,7 +276,8 @@ def backward(ctx, grad_output: torch.Tensor):
279276
padded_group_end_offsets,
280277
) = ctx.saved_tensors
281278

282-
block_size = ctx.block_size
279+
# block_size is always 32 for MXFP8
280+
block_size = 32
283281
out_dtype = ctx.out_dtype
284282
kernel_preference = ctx.kernel_preference
285283
wgrad_with_hp = ctx.wgrad_with_hp
@@ -338,13 +336,12 @@ def backward(ctx, grad_output: torch.Tensor):
338336
return (
339337
grad_input,
340338
grad_weight_t,
341-
None,
342-
None,
343-
None,
344-
None,
345-
None,
346-
None,
347-
None,
339+
None, # group_end_offsets
340+
None, # out_dtype
341+
None, # kernel_preference
342+
None, # wgrad_with_hp
343+
None, # scale_calculation_mode
344+
None, # pad_token_groups_for_grouped_mm
348345
)
349346

350347

0 commit comments

Comments (0)