5757logging .basicConfig (
5858 format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" , level = logging .INFO
5959)
60- _DEVICE = get_current_accelerator_device ()
6160
6261
6362def _build_input_weight (embed_dim : int , device : torch .device , dtype : torch .dtype ):
@@ -131,7 +130,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
131130 def test_reconstruction_qlora_vs_bnb (self , dtype : torch .dtype ):
132131 # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
133132 torch .manual_seed (0 )
134- device = _DEVICE
133+ device = get_current_accelerator_device ()
135134 embed_dim = 512
136135 input_weight = _build_input_weight (embed_dim , device , dtype )
137136 nf4_weight = to_nf4 (input_weight )
@@ -161,12 +160,12 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
161160 """
162161 torch .manual_seed (0 )
163162 dim = 512
164- device = _DEVICE
163+ device = get_current_accelerator_device ()
165164 input_weight = _build_input_weight (dim , device , dtype )
166165 nf4_weight = to_nf4 (input_weight )
167166 bnb_linear = _build_bnb_linear (input_weight , device )
168167
169- inp = torch .randn (2 , 512 , dtype = dtype , device = _DEVICE )
168+ inp = torch .randn (2 , 512 , dtype = dtype , device = device )
170169
171170 out_nf4 = linear_nf4 (inp , nf4_weight ).sum ()
172171 out_bnb = bnb_linear (inp ).sum ()
@@ -181,6 +180,7 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
181180 @parametrize ("dtype" , [torch .bfloat16 , torch .float16 , torch .float32 ])
182181 def test_load_from_state_dicts (self , dtype : torch .dtype ):
183182 """Tests loading to and from different module state dicts"""
183+ _DEVICE = get_current_accelerator_device ()
184184 input_tensor = torch .rand (64 , device = _DEVICE , dtype = dtype )
185185 base_mod = self .TestMod (input_tensor , 32 , 2 )
186186
@@ -224,6 +224,7 @@ def test_to_copy(self, dtype: torch.dtype):
224224 torch .testing .assert_allclose (input_tensor , nf4_to_dtype , atol = 0.13 , rtol = 0.13 )
225225
226226 if torch .accelerator .is_available ():
227+ _DEVICE = get_current_accelerator_device ()
227228 input_tensor = torch .rand (128 , device = _DEVICE )
228229 input_tensor_nf4 = to_nf4 (input_tensor , 32 , 2 )
229230 nf4_to_dtype = input_tensor_nf4 .to (dtype )
@@ -233,6 +234,7 @@ def test_to_copy(self, dtype: torch.dtype):
233234
234235 @unittest .skipIf (not torch .accelerator .is_available (), "Need gpu for test" )
235236 def test_to_copy_device (self ):
237+ _DEVICE = get_current_accelerator_device ()
236238 input_tensor = torch .rand (128 , device = "cpu" )
237239 t = to_nf4 (input_tensor , 32 , 2 )
238240 assert t .device == torch .device ("cpu" )
@@ -256,6 +258,7 @@ def test_to_dtype(self, dtype: torch.dtype):
256258 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
257259 @parametrize ("dtype" , [torch .bfloat16 , torch .float16 , torch .float32 ])
258260 def test_smoketest_linear (self , dtype : torch .dtype ):
261+ _DEVICE = get_current_accelerator_device ()
259262 a = torch .randn (32 , 32 , dtype = dtype , device = _DEVICE )
260263 a_nf4 = torchao .dtypes .to_nf4 (a , 16 , 2 )
261264 inp = torch .randn (2 , 32 , 32 , dtype = a .dtype , device = a .device )
@@ -273,6 +276,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
273276 self .skipTest ("test requires SM capability of at least (8, 0)." )
274277 if version .parse (torch .__version__ ) < version .parse ("2.3.0" ):
275278 self .skipTest ("test requires 2.3.0 and above for tracing NF4Tensor" )
279+ _DEVICE = get_current_accelerator_device ()
276280 a = torch .randn (32 , 32 , dtype = dtype , device = _DEVICE )
277281 a_nf4 = torchao .dtypes .to_nf4 (a , 16 , 2 )
278282 inp = torch .randn (2 , 32 , 32 , dtype = a .dtype , device = a .device )
@@ -283,6 +287,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
283287 @parametrize ("shape" , [(16 , 16 ), (32 , 16 )])
284288 @parametrize ("chunk_size" , [8 , 16 , 32 ])
285289 def test_chunk_size_equivalence (self , dtype : torch .dtype , shape , chunk_size ):
290+ _DEVICE = get_current_accelerator_device ()
286291 a = torch .randn (shape , device = _DEVICE , dtype = dtype )
287292 with unittest .mock .patch ("torchao.dtypes.nf4tensor.CHUNK_SIZE" , chunk_size ):
288293 nf4_patched = to_nf4 (a , 16 , 2 )
@@ -294,6 +299,7 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
294299 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
295300 @parametrize ("input_size" , [(512 * 512 ,), (512 , 512 )])
296301 def test_empty_like (self , input_size : Union [Tuple [int ], int ]):
302+ _DEVICE = get_current_accelerator_device ()
297303 nf4_tensor = to_nf4 (torch .rand (input_size , device = _DEVICE ))
298304 new_tensor = torch .empty_like (nf4_tensor , device = "cpu" )
299305 self .assertTrue (isinstance (new_tensor , NF4Tensor ))
@@ -303,6 +309,7 @@ def test_empty_like(self, input_size: Union[Tuple[int], int]):
303309 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
304310 @parametrize ("compile" , [False , True ])
305311 def test_quantize_api (self , compile ):
312+ _DEVICE = get_current_accelerator_device ()
306313 nf4_linear = nn .Linear (512 , 512 , device = _DEVICE )
307314 torchao .quantize_ (nf4_linear , nf4_weight_only ())
308315 assert isinstance (nf4_linear .weight , NF4Tensor )
@@ -520,13 +527,15 @@ def test_pin_memory(self):
520527 nf4_tensor = nf4_tensor .pin_memory ()
521528 self .assertTrue (nf4_tensor .is_pinned ())
522529
530+ _DEVICE = get_current_accelerator_device ()
523531 nf4_tensor = to_nf4 (torch .randn (512 * 512 , device = _DEVICE ))
524532 self .assertFalse (nf4_tensor .is_pinned ())
525533
526534 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
527535 def test_to_cuda (self ):
528536 nf4_tensor = to_nf4 (torch .randn (512 * 512 ))
529537 self .assertEqual (nf4_tensor .device .type , "cpu" )
538+ _DEVICE = get_current_accelerator_device ()
530539 nf4_tensor = nf4_tensor .to (_DEVICE , non_blocking = True )
531540 self .assertEqual (nf4_tensor .device .type , _DEVICE .type )
532541 self .assertEqual (type (nf4_tensor ), NF4Tensor )
@@ -548,6 +557,7 @@ def test_to_cuda(self):
548557
549558 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
550559 def test_to_cpu (self ):
560+ _DEVICE = get_current_accelerator_device ()
551561 nf4_tensor = to_nf4 (torch .randn (512 * 512 , device = _DEVICE ))
552562 nf4_tensor = nf4_tensor .cpu ()
553563 self .assertEqual (nf4_tensor .device .type , "cpu" )
@@ -562,6 +572,7 @@ def test_to_module(self):
562572 linear .weight = nn .Parameter (
563573 to_nf4 (linear .weight .detach ()), requires_grad = False
564574 )
575+ _DEVICE = get_current_accelerator_device ()
565576 linear .to (_DEVICE )
566577 self .assertEqual (linear .weight .device .type , _DEVICE .type )
567578 weight = linear .weight .get_original_weight ()
@@ -589,6 +600,7 @@ def test_to_module(self):
589600 @unittest .skipIf (not torch .accelerator .is_available (), "Need GPU available" )
590601 @parametrize ("input_size" , [512 * 512 , (512 * 512 ,), (512 , 512 )])
591602 def test_tensor_deepcopy (self , input_size : Union [Tuple [int ], int ]):
603+ _DEVICE = get_current_accelerator_device ()
592604 nf4_orig = to_nf4 (torch .randn (input_size , device = _DEVICE ))
593605 nf4_clone = copy .deepcopy (nf4_orig )
594606 self .assertEqual (
@@ -679,6 +691,7 @@ def _test_qlora_fsdp2(
679691 dropout_p = 0 ,
680692 )
681693 torch .manual_seed (42 )
694+ _DEVICE = get_current_accelerator_device ()
682695 with torch .device (_DEVICE ):
683696 base_model = Transformer (model_args )
684697 for layer in base_model .layers :
@@ -768,6 +781,7 @@ def _test_comm(self, input_size: int):
768781 from torch .distributed ._composable .fsdp import fully_shard
769782 from torch .distributed ._tensor import distribute_tensor
770783
784+ _DEVICE = get_current_accelerator_device ()
771785 model = nn .Linear (input_size , input_size , device = _DEVICE )
772786 origin_tensor = model .weight
773787 origin_nf4_tensor = to_nf4 (origin_tensor )
0 commit comments