Skip to content

Commit 05f4f81

Browse files
committed
Refine device handling in int8 tensor tests: use the current accelerator device in the GPU-kernel test and parametrize available devices in the static-quant test
1 parent 201c960 commit 05f4f81

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,12 @@ def test_dequantization_accuracy(self, config, device):
230230
@unittest.skipIf(
231231
not torch_version_at_least("2.7.0"), "torch 2.6.0 and below has custom fx pass"
232232
)
233-
@common_utils.parametrize("device", get_available_devices())
234-
def test_available_gpu_kernels(self, device):
233+
def test_available_gpu_kernels(self):
235234
"""Check which GPU kernels are used"""
236235
torch.compiler.reset()
237236

238237
M, K, N = 128, 256, 512
238+
device = get_current_accelerator_device()
239239
m = torch.nn.Sequential(
240240
torch.nn.Linear(K, N, device=device, dtype=torch.bfloat16)
241241
)
@@ -286,7 +286,7 @@ class TestInt8StaticQuant(TorchAOIntegrationTestCase):
286286
@common_utils.parametrize(
287287
"act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
288288
)
289-
@common_utils.parametrize("device", get_current_accelerator_device)
289+
@common_utils.parametrize("device", get_available_devices())
290290
def test_static_activation_per_row_int8_weight(
291291
self, granularity, dtype, act_mapping_type, device
292292
):

0 commit comments

Comments (0)