Skip to content

Commit 6d4c02a

Browse files
committed
move _DEVICE into the test function which is guarded by @unittest.skipIf(not torch.accelerator.is_available())
1 parent 1d0766a commit 6d4c02a

File tree

13 files changed

+120
-79
lines changed

13 files changed

+120
-79
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@
4545
is_cusparselt_available = (
4646
hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
4747
)
48-
_DEVICE = get_current_accelerator_device()
4948

5049

5150
def get_quantization_functions(
52-
do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
51+
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
5352
):
5453
base_functions = [
5554
Int8WeightOnlyConfig(),
@@ -85,10 +84,12 @@ def get_quantization_functions(
8584
return base_functions
8685

8786

87+
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
8888
class TestAffineQuantized(TestCase):
8989
GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
9090
["xpu"] if torch.xpu.is_available() else []
9191
)
92+
_DEVICE = get_current_accelerator_device()
9293

9394
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
9495
def test_weights_only(self):
@@ -110,7 +111,9 @@ def test_weights_only(self):
110111
_ = torch.load(f, weights_only=True)
111112

112113
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
113-
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
114+
@common_utils.parametrize(
115+
"apply_quant", get_quantization_functions(False, False, _DEVICE)
116+
)
114117
def test_to_device(self, apply_quant):
115118
for device in self.GPU_DEVICES:
116119

@@ -171,6 +174,7 @@ def apply_uint6_weight_only_quant(linear):
171174
)
172175
return linear
173176

177+
_DEVICE = get_current_accelerator_device()
174178
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
175179
apply_uint6_weight_only_quant(linear)
176180

@@ -202,6 +206,7 @@ def test_print_quantized_module(self):
202206
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
203207
)
204208
def test_test_copy__apply(self, apply_quant):
209+
_DEVICE = get_current_accelerator_device()
205210
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
206211
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
207212

@@ -226,6 +231,7 @@ def test_test_copy__apply(self, apply_quant):
226231
"apply_quant", get_quantization_functions(False, True, _DEVICE, False)
227232
)
228233
def test_copy__mismatch_metadata(self, apply_quant):
234+
_DEVICE = get_current_accelerator_device()
229235
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
230236
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)
231237

@@ -301,9 +307,8 @@ def test_alias(self, device, dtype):
301307
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
302308
_ = dummy.weight[...]
303309

304-
@common_utils.parametrize("device", [_DEVICE])
310+
@common_utils.parametrize("device", ["cuda"])
305311
@common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
306-
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
307312
@skip_if_no_gemlite()
308313
def test_slice_gemlite(self, device, dtype):
309314
# in_feature not divisible by 1024
@@ -384,7 +389,7 @@ def dequant(input_layer, in_features, orig_shape):
384389
)
385390
self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)
386391

387-
@common_utils.parametrize("device", [_DEVICE])
392+
@common_utils.parametrize("device", ["cuda"])
388393
@common_utils.parametrize("dtype", [torch.bfloat16])
389394
def test_matmul(self, device, dtype):
390395
x = torch.randn(53, 2048)

test/dtypes/test_affine_quantized_float.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
random.seed(0)
3838
torch.manual_seed(0)
39-
_DEVICE = get_current_accelerator_device()
4039

4140

4241
class ToyLinearModel(torch.nn.Module):
@@ -53,15 +52,15 @@ def forward(self, x):
5352

5453
class TestAffineQuantizedFloat8Compile(InductorTestCase):
5554
@unittest.skipIf(
56-
_DEVICE == "cuda" and not is_sm_at_least_89(),
55+
torch.cuda.is_available() and not is_sm_at_least_89(),
5756
"Requires GPU with compute capability >= 8.9",
5857
)
5958
def test_invalid_granularity(self):
6059
with pytest.raises(ValueError, match="Invalid granularity specification"):
6160
Float8DynamicActivationFloat8WeightConfig(granularity="invalid")
6261

6362
@unittest.skipIf(
64-
_DEVICE == "cuda" and not is_sm_at_least_89(),
63+
torch.cuda.is_available() and not is_sm_at_least_89(),
6564
"Requires GPU with compute capability >= 8.9",
6665
)
6766
def test_mismatched_granularity(self):
@@ -74,7 +73,7 @@ def test_mismatched_granularity(self):
7473
)
7574

7675
@unittest.skipIf(
77-
_DEVICE == "cuda" and not is_sm_at_least_89(),
76+
torch.cuda.is_available() and not is_sm_at_least_89(),
7877
"Requires GPU with compute capability >= 8.9",
7978
)
8079
def test_unsupported_granularity(self):
@@ -95,20 +94,21 @@ def test_per_row_with_float32(self):
9594
AssertionError,
9695
match="PerRow quantization only works for bfloat16 precision",
9796
):
98-
model = ToyLinearModel(64, 64).eval().to(torch.float32).to(_DEVICE)
97+
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
9998
quantize_(
10099
model,
101100
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
102101
)
103102

104103
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
105104
@unittest.skipIf(
106-
_DEVICE == "cuda" and not is_sm_at_least_89(),
105+
torch.cuda.is_available() and not is_sm_at_least_89(),
107106
"Requires GPU with compute capability >= 8.9",
108107
)
109108
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
110109
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
111110
def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
111+
_DEVICE = get_current_accelerator_device()
112112
block_size = ()
113113
device = _DEVICE
114114
input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)
@@ -147,15 +147,15 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
147147

148148
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
149149
@unittest.skipIf(
150-
_DEVICE == "cuda" and not is_sm_at_least_89(),
150+
torch.cuda.is_available() and not is_sm_at_least_89(),
151151
"Requires GPU with compute capability >= 8.9",
152152
)
153153
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
154154
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
155155
@common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
156156
def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
157157
"""Test _dequantize_affine_float8 with various configurations"""
158-
158+
_DEVICE = get_current_accelerator_device()
159159
device = _DEVICE
160160
input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)
161161

@@ -181,12 +181,12 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
181181

182182
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
183183
@unittest.skipIf(
184-
_DEVICE == "cuda" and not is_sm_at_least_89(),
184+
torch.cuda.is_available() and not is_sm_at_least_89(),
185185
"Requires GPU with compute capability >= 8.9",
186186
)
187187
def test_dequantize_affine_float8_scale_broadcasting(self):
188188
"""Test that scale broadcasting works correctly for block-wise quantization"""
189-
device = _DEVICE
189+
device = get_current_accelerator_device()
190190
# Create input tensor with known block structure
191191
input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32)
192192
block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim
@@ -314,7 +314,7 @@ def test_expected_kernels_on_gpu(self, granularity):
314314

315315
M, K, N = 128, 256, 512
316316
m = torch.nn.Sequential(
317-
torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
317+
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
318318
)
319319
config = Float8DynamicActivationFloat8WeightConfig(
320320
granularity=granularity,
@@ -327,7 +327,7 @@ def test_expected_kernels_on_gpu(self, granularity):
327327
)
328328

329329
m = torch.compile(m)
330-
x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
330+
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
331331
out, code = run_and_get_code(m, x)
332332

333333
# triton kernel call looks like:

test/dtypes/test_bitpacking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
bit_widths = (1, 2, 3, 4, 5, 6, 7)
1414
dimensions = (0, -1, 1)
15-
_DEVICE = get_current_accelerator_device()
1615

1716

1817
@pytest.fixture(autouse=True)
@@ -36,6 +35,7 @@ def test_CPU(bit_width, dim):
3635
@pytest.mark.parametrize("bit_width", bit_widths)
3736
@pytest.mark.parametrize("dim", dimensions)
3837
def test_GPU(bit_width, dim):
38+
_DEVICE = get_current_accelerator_device()
3939
test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
4040
_DEVICE
4141
)
@@ -49,6 +49,7 @@ def test_GPU(bit_width, dim):
4949
@pytest.mark.parametrize("bit_width", bit_widths)
5050
@pytest.mark.parametrize("dim", dimensions)
5151
def test_compile(bit_width, dim):
52+
_DEVICE = get_current_accelerator_device()
5253
torch._dynamo.config.specialize_int = True
5354
torch.compile(pack, fullgraph=True)
5455
torch.compile(unpack, fullgraph=True)
@@ -63,6 +64,7 @@ def test_compile(bit_width, dim):
6364
# these test cases are for the example pack walk through in the bitpacking.py file
6465
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
6566
def test_pack_example():
67+
_DEVICE = get_current_accelerator_device()
6668
test_tensor = torch.tensor(
6769
[0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
6870
).to(_DEVICE)

test/dtypes/test_nf4.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
logging.basicConfig(
5858
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
5959
)
60-
_DEVICE = get_current_accelerator_device()
6160

6261

6362
def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
@@ -131,7 +130,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
131130
def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
132131
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
133132
torch.manual_seed(0)
134-
device = _DEVICE
133+
device = get_current_accelerator_device()
135134
embed_dim = 512
136135
input_weight = _build_input_weight(embed_dim, device, dtype)
137136
nf4_weight = to_nf4(input_weight)
@@ -161,12 +160,12 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
161160
"""
162161
torch.manual_seed(0)
163162
dim = 512
164-
device = _DEVICE
163+
device = get_current_accelerator_device()
165164
input_weight = _build_input_weight(dim, device, dtype)
166165
nf4_weight = to_nf4(input_weight)
167166
bnb_linear = _build_bnb_linear(input_weight, device)
168167

169-
inp = torch.randn(2, 512, dtype=dtype, device=_DEVICE)
168+
inp = torch.randn(2, 512, dtype=dtype, device=device)
170169

171170
out_nf4 = linear_nf4(inp, nf4_weight).sum()
172171
out_bnb = bnb_linear(inp).sum()
@@ -181,6 +180,7 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
181180
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
182181
def test_load_from_state_dicts(self, dtype: torch.dtype):
183182
"""Tests loading to and from different module state dicts"""
183+
_DEVICE = get_current_accelerator_device()
184184
input_tensor = torch.rand(64, device=_DEVICE, dtype=dtype)
185185
base_mod = self.TestMod(input_tensor, 32, 2)
186186

@@ -224,6 +224,7 @@ def test_to_copy(self, dtype: torch.dtype):
224224
torch.testing.assert_allclose(input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
225225

226226
if torch.accelerator.is_available():
227+
_DEVICE = get_current_accelerator_device()
227228
input_tensor = torch.rand(128, device=_DEVICE)
228229
input_tensor_nf4 = to_nf4(input_tensor, 32, 2)
229230
nf4_to_dtype = input_tensor_nf4.to(dtype)
@@ -233,6 +234,7 @@ def test_to_copy(self, dtype: torch.dtype):
233234

234235
@unittest.skipIf(not torch.accelerator.is_available(), "Need gpu for test")
235236
def test_to_copy_device(self):
237+
_DEVICE = get_current_accelerator_device()
236238
input_tensor = torch.rand(128, device="cpu")
237239
t = to_nf4(input_tensor, 32, 2)
238240
assert t.device == torch.device("cpu")
@@ -256,6 +258,7 @@ def test_to_dtype(self, dtype: torch.dtype):
256258
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
257259
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
258260
def test_smoketest_linear(self, dtype: torch.dtype):
261+
_DEVICE = get_current_accelerator_device()
259262
a = torch.randn(32, 32, dtype=dtype, device=_DEVICE)
260263
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
261264
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
@@ -273,6 +276,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
273276
self.skipTest("test requires SM capability of at least (8, 0).")
274277
if version.parse(torch.__version__) < version.parse("2.3.0"):
275278
self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
279+
_DEVICE = get_current_accelerator_device()
276280
a = torch.randn(32, 32, dtype=dtype, device=_DEVICE)
277281
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
278282
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
@@ -283,6 +287,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
283287
@parametrize("shape", [(16, 16), (32, 16)])
284288
@parametrize("chunk_size", [8, 16, 32])
285289
def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
290+
_DEVICE = get_current_accelerator_device()
286291
a = torch.randn(shape, device=_DEVICE, dtype=dtype)
287292
with unittest.mock.patch("torchao.dtypes.nf4tensor.CHUNK_SIZE", chunk_size):
288293
nf4_patched = to_nf4(a, 16, 2)
@@ -294,6 +299,7 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
294299
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
295300
@parametrize("input_size", [(512 * 512,), (512, 512)])
296301
def test_empty_like(self, input_size: Union[Tuple[int], int]):
302+
_DEVICE = get_current_accelerator_device()
297303
nf4_tensor = to_nf4(torch.rand(input_size, device=_DEVICE))
298304
new_tensor = torch.empty_like(nf4_tensor, device="cpu")
299305
self.assertTrue(isinstance(new_tensor, NF4Tensor))
@@ -303,6 +309,7 @@ def test_empty_like(self, input_size: Union[Tuple[int], int]):
303309
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
304310
@parametrize("compile", [False, True])
305311
def test_quantize_api(self, compile):
312+
_DEVICE = get_current_accelerator_device()
306313
nf4_linear = nn.Linear(512, 512, device=_DEVICE)
307314
torchao.quantize_(nf4_linear, nf4_weight_only())
308315
assert isinstance(nf4_linear.weight, NF4Tensor)
@@ -520,13 +527,15 @@ def test_pin_memory(self):
520527
nf4_tensor = nf4_tensor.pin_memory()
521528
self.assertTrue(nf4_tensor.is_pinned())
522529

530+
_DEVICE = get_current_accelerator_device()
523531
nf4_tensor = to_nf4(torch.randn(512 * 512, device=_DEVICE))
524532
self.assertFalse(nf4_tensor.is_pinned())
525533

526534
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
527535
def test_to_cuda(self):
528536
nf4_tensor = to_nf4(torch.randn(512 * 512))
529537
self.assertEqual(nf4_tensor.device.type, "cpu")
538+
_DEVICE = get_current_accelerator_device()
530539
nf4_tensor = nf4_tensor.to(_DEVICE, non_blocking=True)
531540
self.assertEqual(nf4_tensor.device.type, _DEVICE.type)
532541
self.assertEqual(type(nf4_tensor), NF4Tensor)
@@ -548,6 +557,7 @@ def test_to_cuda(self):
548557

549558
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
550559
def test_to_cpu(self):
560+
_DEVICE = get_current_accelerator_device()
551561
nf4_tensor = to_nf4(torch.randn(512 * 512, device=_DEVICE))
552562
nf4_tensor = nf4_tensor.cpu()
553563
self.assertEqual(nf4_tensor.device.type, "cpu")
@@ -562,6 +572,7 @@ def test_to_module(self):
562572
linear.weight = nn.Parameter(
563573
to_nf4(linear.weight.detach()), requires_grad=False
564574
)
575+
_DEVICE = get_current_accelerator_device()
565576
linear.to(_DEVICE)
566577
self.assertEqual(linear.weight.device.type, _DEVICE.type)
567578
weight = linear.weight.get_original_weight()
@@ -589,6 +600,7 @@ def test_to_module(self):
589600
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
590601
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
591602
def test_tensor_deepcopy(self, input_size: Union[Tuple[int], int]):
603+
_DEVICE = get_current_accelerator_device()
592604
nf4_orig = to_nf4(torch.randn(input_size, device=_DEVICE))
593605
nf4_clone = copy.deepcopy(nf4_orig)
594606
self.assertEqual(
@@ -679,6 +691,7 @@ def _test_qlora_fsdp2(
679691
dropout_p=0,
680692
)
681693
torch.manual_seed(42)
694+
_DEVICE = get_current_accelerator_device()
682695
with torch.device(_DEVICE):
683696
base_model = Transformer(model_args)
684697
for layer in base_model.layers:
@@ -768,6 +781,7 @@ def _test_comm(self, input_size: int):
768781
from torch.distributed._composable.fsdp import fully_shard
769782
from torch.distributed._tensor import distribute_tensor
770783

784+
_DEVICE = get_current_accelerator_device()
771785
model = nn.Linear(input_size, input_size, device=_DEVICE)
772786
origin_tensor = model.weight
773787
origin_nf4_tensor = to_nf4(origin_tensor)

0 commit comments

Comments
 (0)