|
30 | 30 | from torchao.testing.model_architectures import ToyTwoLinearModel |
31 | 31 | from torchao.testing.utils import TorchAOIntegrationTestCase |
32 | 32 | from torchao.utils import ( |
| 33 | + get_available_devices, |
33 | 34 | get_current_accelerator_device, |
34 | 35 | torch_version_at_least, |
35 | | - get_available_devices, |
36 | 36 | ) |
37 | 37 |
|
38 | 38 | INT8_TEST_CONFIGS = [ |
@@ -267,15 +267,10 @@ class TestInt8StaticQuant(TorchAOIntegrationTestCase): |
267 | 267 | def test_static_activation_per_row_int8_weight(self, granularity, dtype): |
268 | 268 | torch.compiler.reset() |
269 | 269 |
|
270 | | -<<<<<<< HEAD |
271 | 270 | M, N, K = 128, 128, 128 |
272 | | - input_tensor = torch.randn(M, K, dtype=dtype, device="cuda") |
273 | | -======= |
274 | | - M, N, K = 32, 32, 32 |
275 | 271 |
|
276 | 272 | _DEVICE = get_current_accelerator_device() |
277 | 273 | input_tensor = torch.randn(M, K, dtype=dtype, device=_DEVICE) |
278 | | ->>>>>>> f07387cd2 (refine the device) |
279 | 274 |
|
280 | 275 | model = torch.nn.Linear(K, N, bias=False).eval().to(device=_DEVICE, dtype=dtype) |
281 | 276 | model_static_quant = copy.deepcopy(model) |
@@ -353,7 +348,7 @@ def test_static_act_quant_slice_and_select(self, granularity): |
353 | 348 | N, K = 256, 512 |
354 | 349 | M = 32 # batch size |
355 | 350 | dtype = torch.bfloat16 |
356 | | - device = "cuda" |
| 351 | + device = get_current_accelerator_device() |
357 | 352 |
|
358 | 353 | linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) |
359 | 354 | input_tensor = torch.randn(M, K, dtype=dtype, device=device) |
@@ -414,11 +409,12 @@ def test_int8_weight_only_v2_correct_eps(self, dtype): |
414 | 409 | torch.manual_seed(42) |
415 | 410 |
|
416 | 411 | # Create test model |
417 | | - model = ToyTwoLinearModel(256, 128, 256, dtype=dtype, device="cuda").eval() |
| 412 | + _DEVICE = get_current_accelerator_device() |
| 413 | + model = ToyTwoLinearModel(256, 128, 256, dtype=dtype, device=_DEVICE).eval() |
418 | 414 | model_baseline = copy.deepcopy(model) |
419 | 415 |
|
420 | 416 | # Create input |
421 | | - input_tensor = torch.randn(32, 256, dtype=dtype, device="cuda") |
| 417 | + input_tensor = torch.randn(32, 256, dtype=dtype, device=_DEVICE) |
422 | 418 |
|
423 | 419 | # Get baseline output |
424 | 420 | output_baseline = model_baseline(input_tensor) |
|
0 commit comments