Skip to content

Commit e23dc55

Browse files
committed
refine the device
1 parent 93d4738 commit e23dc55

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
from torchao.testing.model_architectures import ToyTwoLinearModel
3131
from torchao.testing.utils import TorchAOIntegrationTestCase
3232
from torchao.utils import (
33+
get_available_devices,
3334
get_current_accelerator_device,
3435
torch_version_at_least,
35-
get_available_devices,
3636
)
3737

3838
INT8_TEST_CONFIGS = [
@@ -267,15 +267,10 @@ class TestInt8StaticQuant(TorchAOIntegrationTestCase):
267267
def test_static_activation_per_row_int8_weight(self, granularity, dtype):
268268
torch.compiler.reset()
269269

270-
<<<<<<< HEAD
271270
M, N, K = 128, 128, 128
272-
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")
273-
=======
274-
M, N, K = 32, 32, 32
275271

276272
_DEVICE = get_current_accelerator_device()
277273
input_tensor = torch.randn(M, K, dtype=dtype, device=_DEVICE)
278-
>>>>>>> f07387cd2 (refine the device)
279274

280275
model = torch.nn.Linear(K, N, bias=False).eval().to(device=_DEVICE, dtype=dtype)
281276
model_static_quant = copy.deepcopy(model)
@@ -353,7 +348,7 @@ def test_static_act_quant_slice_and_select(self, granularity):
353348
N, K = 256, 512
354349
M = 32 # batch size
355350
dtype = torch.bfloat16
356-
device = "cuda"
351+
device = get_current_accelerator_device()
357352

358353
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
359354
input_tensor = torch.randn(M, K, dtype=dtype, device=device)
@@ -414,11 +409,12 @@ def test_int8_weight_only_v2_correct_eps(self, dtype):
414409
torch.manual_seed(42)
415410

416411
# Create test model
417-
model = ToyTwoLinearModel(256, 128, 256, dtype=dtype, device="cuda").eval()
412+
_DEVICE = get_current_accelerator_device()
413+
model = ToyTwoLinearModel(256, 128, 256, dtype=dtype, device=_DEVICE).eval()
418414
model_baseline = copy.deepcopy(model)
419415

420416
# Create input
421-
input_tensor = torch.randn(32, 256, dtype=dtype, device="cuda")
417+
input_tensor = torch.randn(32, 256, dtype=dtype, device=_DEVICE)
422418

423419
# Get baseline output
424420
output_baseline = model_baseline(input_tensor)

0 commit comments

Comments (0)