"""tests for Quantum Data Loader."""
19+ from unittest .mock import patch
20+
21+ import numpy as np
1922import pytest
2023
2124try :
@@ -28,6 +31,15 @@ def _loader_available():
2831 return QuantumDataLoader is not None
2932
3033
34+ def _cuda_available ():
35+ try :
36+ import torch
37+
38+ return torch .cuda .is_available ()
39+ except ImportError :
40+ return False
41+
42+
3143@pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
3244def test_mutual_exclusion_both_sources_raises ():
3345 """Calling both .source_synthetic() and .source_file() then __iter__ raises ValueError."""
@@ -184,3 +196,134 @@ def test_null_handling_default_is_none():
184196 """By default, _null_handling is None (Rust will use FillZero)."""
185197 loader = QuantumDataLoader (device_id = 0 )
186198 assert loader ._null_handling is None
199+
200+
201+ # --- as_torch() / as_numpy() output format tests ---
202+
203+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_as_torch_raises_at_config_time_when_torch_missing():
    """as_torch() raises RuntimeError immediately (config time) when torch is not installed."""
    # Simulate "torch not installed" by nulling the module handle the loader holds.
    with patch("qumat_qdp.loader._torch", None):
        loader = QuantumDataLoader(device_id=0).qubits(4).batches(2, size=4)
        with pytest.raises(RuntimeError) as exc_info:
            loader.as_torch()
        message = str(exc_info.value)
        # The error should name torch/PyTorch and tell the user how to install it.
        assert "PyTorch" in message or "torch" in message.lower()
        assert "pip install" in message
214+
215+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
def test_as_numpy_succeeds_at_config_time_without_torch():
    """as_numpy() does not raise at config time even when torch is not installed."""
    with patch("qumat_qdp.loader._torch", None):
        configured = QuantumDataLoader(device_id=0).qubits(4).batches(2, size=4)
        configured = configured.source_synthetic().as_numpy()
    assert configured._output_format == ("numpy",)
228+
229+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.skipif(not _cuda_available(), reason="CUDA GPU required")
def test_as_numpy_yields_float64_arrays():
    """as_numpy() yields numpy float64 arrays with correct shape; no torch required."""
    num_qubits = 4
    batch_size = 8
    state_len = 2**num_qubits  # 16 amplitudes per state vector

    # Consume the loader with torch patched out to prove numpy output
    # never touches torch.
    with patch("qumat_qdp.loader._torch", None):
        loader = (
            QuantumDataLoader(device_id=0)
            .qubits(num_qubits)
            .batches(3, size=batch_size)
            .source_synthetic()
            .as_numpy()
        )
        collected = list(loader)

    assert len(collected) == 3
    for arr in collected:
        assert isinstance(arr, np.ndarray), f"expected ndarray, got {type(arr)} "
        assert arr.dtype == np.float64, f"expected float64, got {arr.dtype} "
        assert arr.ndim == 2
        assert arr.shape == (batch_size, state_len), f"unexpected shape {arr.shape} "
256+
257+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.skipif(not _cuda_available(), reason="CUDA GPU required")
def test_as_numpy_amplitudes_are_unit_norm():
    """Each row from as_numpy() should be a unit-norm state vector (amplitude encoding)."""
    num_qubits = 4
    batch_size = 16

    loader = (
        QuantumDataLoader(device_id=0)
        .qubits(num_qubits)
        .batches(2, size=batch_size)
        .source_synthetic()
        .as_numpy()
    )
    for chunk in loader:
        data = np.asarray(chunk, dtype=np.float64)
        # Every state vector must have L2 norm 1 (up to float tolerance).
        row_norms = np.linalg.norm(data, axis=1)
        np.testing.assert_allclose(row_norms, 1.0, atol=1e-5)
276+
277+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.skipif(not _cuda_available(), reason="CUDA GPU required")
def test_as_torch_yields_cuda_tensors():
    """as_torch(device='cuda') yields torch tensors on CUDA."""
    # torch is needed locally for the isinstance check below; skip if absent.
    try:
        import torch
    except ImportError:
        pytest.skip("torch not installed")

    num_qubits = 4
    batch_size = 8
    state_len = 2**num_qubits

    loader = (
        QuantumDataLoader(device_id=0)
        .qubits(num_qubits)
        .batches(2, size=batch_size)
        .source_synthetic()
        .as_torch(device="cuda")
    )
    for tensor in loader:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.is_cuda
        assert tensor.shape == (batch_size, state_len)
302+
303+
@pytest.mark.skipif(not _loader_available(), reason="QuantumDataLoader not available")
@pytest.mark.skipif(not _cuda_available(), reason="CUDA GPU required")
def test_as_numpy_from_source_array():
    """as_numpy() works with source_array(), yielding correct shapes and dtype."""
    num_qubits = 3
    state_len = 2**num_qubits  # 8
    n_samples = 12
    batch_size = 4

    rng = np.random.default_rng(42)
    samples = rng.standard_normal((n_samples, state_len))

    # NOTE(review): .batches(1, ...) yet n_samples // batch_size == 3 batches are
    # expected — presumably source_array() overrides the batch count; confirm.
    loader = (
        QuantumDataLoader(device_id=0)
        .qubits(num_qubits)
        .batches(1, size=batch_size)
        .encoding("amplitude")
        .source_array(samples)
        .as_numpy()
    )
    produced = list(loader)
    assert len(produced) == n_samples // batch_size
    for arr in produced:
        assert isinstance(arr, np.ndarray)
        assert arr.dtype == np.float64
        assert arr.shape[1] == state_len
# (trailing scrape artifact removed)