|
31 | 31 | ) |
32 | 32 | from torchao.quantization.quant_api import quantize_ |
33 | 33 | from torchao.quantization.quantize_.common import KernelPreference |
34 | | -from torchao.utils import is_MI300, is_MI350, is_ROCM |
35 | 34 |
|
36 | 35 | # Reference MoE implementation (copied from torchtitan to avoid external dependency) |
37 | 36 | from .reference_moe import MoE, MoEArgs, set_token_group_alignment_size_m |
@@ -102,34 +101,23 @@ def test_moe_training( |
102 | 101 | ) |
103 | 102 | assert torch.cuda.is_available() |
104 | 103 |
|
105 | | - # Per-group padding has known shape mismatch issues with experts on ROCm |
106 | | - # (introduced in #3998). Skip until resolved. |
107 | | - if is_ROCM() and "experts" in target_fqns: |
108 | | - pytest.skip( |
109 | | - "MoE expert training has known shape mismatch on ROCm (per-group padding, see #3998)" |
110 | | - ) |
111 | | - |
112 | 104 | # Emulated mode with compile is not supported |
113 | 105 | if recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL and compile: |
114 | 106 | pytest.skip( |
115 | 107 | "Skipping compile=True with kernel_preference=EMULATED, not currently supported" |
116 | 108 | ) |
117 | 109 |
|
118 | | - # FP8_ROWWISE hardware path requires SM90 (CUDA) or MI300/MI350 (ROCm) |
| 110 | + # FP8_ROWWISE hardware path requires SM90 |
119 | 111 | if recipe == Float8TrainingRecipe.FP8_ROWWISE: |
120 | 112 | if compile: |
121 | 113 | pytest.skip( |
122 | 114 | "https://github.com/pytorch/ao/issues/4048: 'FakeTensor' object has no attribute '__tensor_flatten__'" |
123 | 115 | ) |
124 | 116 |
|
125 | | - if is_ROCM(): |
126 | | - if not (is_MI300() or is_MI350()): |
127 | | - pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm") |
128 | | - else: |
129 | | - if torch.cuda.get_device_capability() != (9, 0): |
130 | | - pytest.skip( |
131 | | - f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" |
132 | | - ) |
| 117 | + if torch.cuda.get_device_capability() != (9, 0): |
| 118 | + pytest.skip( |
| 119 | + f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" |
| 120 | + ) |
133 | 121 | if not token_groups_aligned: |
134 | 122 | pytest.skip("FP8 rowwise doesn't support per group token padding yet") |
135 | 123 |
|
|
0 commit comments