Skip to content

Commit 1402ea3

Browse files
committed
update
1 parent ad243e6 commit 1402ea3

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_xpu
3232
from torchao.utils import (
3333
get_available_devices,
34-
get_current_accelerator_device,
3534
torch_version_at_least,
3635
)
3736

@@ -227,7 +226,6 @@ def test_dequantization_accuracy(self, config, device):
227226
f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}"
228227
)
229228

230-
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
231229
@unittest.skipIf(
232230
not torch_version_at_least("2.7.0"), "torch 2.6.0 and below has custom fx pass"
233231
)
@@ -236,7 +234,7 @@ def test_available_gpu_kernels(self):
236234
torch.compiler.reset()
237235

238236
M, K, N = 128, 256, 512
239-
device = get_available_devices()
237+
device = "xpu" if torch.xpu.is_available() else "cuda"
240238
m = torch.nn.Sequential(
241239
torch.nn.Linear(K, N, device=device, dtype=torch.bfloat16)
242240
)

0 commit comments

Comments (0)