Skip to content

Commit d17c61b

Browse files
clean up unused rocm references in test_training.py (#4170)
1 parent 136cacb commit d17c61b

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

test/prototype/moe_training/test_training.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
)
3232
from torchao.quantization.quant_api import quantize_
3333
from torchao.quantization.quantize_.common import KernelPreference
34-
from torchao.utils import is_MI300, is_MI350, is_ROCM
3534

3635
# Reference MoE implementation (copied from torchtitan to avoid external dependency)
3736
from .reference_moe import MoE, MoEArgs, set_token_group_alignment_size_m
@@ -102,34 +101,23 @@ def test_moe_training(
102101
)
103102
assert torch.cuda.is_available()
104103

105-
# Per-group padding has known shape mismatch issues with experts on ROCm
106-
# (introduced in #3998). Skip until resolved.
107-
if is_ROCM() and "experts" in target_fqns:
108-
pytest.skip(
109-
"MoE expert training has known shape mismatch on ROCm (per-group padding, see #3998)"
110-
)
111-
112104
# Emulated mode with compile is not supported
113105
if recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL and compile:
114106
pytest.skip(
115107
"Skipping compile=True with kernel_preference=EMULATED, not currently supported"
116108
)
117109

118-
# FP8_ROWWISE hardware path requires SM90 (CUDA) or MI300/MI350 (ROCm)
110+
# FP8_ROWWISE hardware path requires SM90
119111
if recipe == Float8TrainingRecipe.FP8_ROWWISE:
120112
if compile:
121113
pytest.skip(
122114
"https://github.com/pytorch/ao/issues/4048: 'FakeTensor' object has no attribute '__tensor_flatten__'"
123115
)
124116

125-
if is_ROCM():
126-
if not (is_MI300() or is_MI350()):
127-
pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")
128-
else:
129-
if torch.cuda.get_device_capability() != (9, 0):
130-
pytest.skip(
131-
f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
132-
)
117+
if torch.cuda.get_device_capability() != (9, 0):
118+
pytest.skip(
119+
f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
120+
)
133121
if not token_groups_aligned:
134122
pytest.skip("FP8 rowwise doesn't support per group token padding yet")
135123

0 commit comments

Comments
 (0)