1616
1717"""tests for Quantum Data Loader."""
1818
19+ from unittest .mock import patch
20+
21+ import numpy as np
1922import pytest
2023
2124try :
@@ -28,6 +31,15 @@ def _loader_available():
2831 return QuantumDataLoader is not None
2932
3033
34+ def _cuda_available ():
35+ try :
36+ import torch
37+
38+ return torch .cuda .is_available ()
39+ except ImportError :
40+ return False
41+
42+
3143@pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
3244def test_mutual_exclusion_both_sources_raises () -> None :
3345 """Calling both .source_synthetic() and .source_file() then __iter__ raises ValueError."""
@@ -238,3 +250,134 @@ def test_source_file_s3_streaming_non_parquet_raises(path):
238250 )
239251 msg = str (exc_info .value ).lower ()
240252 assert "parquet" in msg or "streaming" in msg
253+
254+
255+ # --- as_torch() / as_numpy() output format tests ---
256+
257+
258+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
259+ def test_as_torch_raises_at_config_time_when_torch_missing ():
260+ """as_torch() raises RuntimeError immediately (config time) when torch is not installed."""
261+ with patch ("qumat_qdp.loader._torch" , None ):
262+ loader = QuantumDataLoader (device_id = 0 ).qubits (4 ).batches (2 , size = 4 )
263+ with pytest .raises (RuntimeError ) as exc_info :
264+ loader .as_torch ()
265+ msg = str (exc_info .value )
266+ assert "PyTorch" in msg or "torch" in msg .lower ()
267+ assert "pip install" in msg
268+
269+
270+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
271+ def test_as_numpy_succeeds_at_config_time_without_torch ():
272+ """as_numpy() does not raise at config time even when torch is not installed."""
273+ with patch ("qumat_qdp.loader._torch" , None ):
274+ loader = (
275+ QuantumDataLoader (device_id = 0 )
276+ .qubits (4 )
277+ .batches (2 , size = 4 )
278+ .source_synthetic ()
279+ .as_numpy ()
280+ )
281+ assert loader ._output_format == ("numpy" ,)
282+
283+
284+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
285+ @pytest .mark .skipif (not _cuda_available (), reason = "CUDA GPU required" )
286+ def test_as_numpy_yields_float64_arrays ():
287+ """as_numpy() yields numpy float64 arrays with correct shape; no torch required."""
288+ num_qubits = 4
289+ batch_size = 8
290+ state_len = 2 ** num_qubits # 16
291+
292+ batches = []
293+ with patch ("qumat_qdp.loader._torch" , None ):
294+ loader = (
295+ QuantumDataLoader (device_id = 0 )
296+ .qubits (num_qubits )
297+ .batches (3 , size = batch_size )
298+ .source_synthetic ()
299+ .as_numpy ()
300+ )
301+ for batch in loader :
302+ batches .append (batch )
303+
304+ assert len (batches ) == 3
305+ for batch in batches :
306+ assert isinstance (batch , np .ndarray ), f"expected ndarray, got { type (batch )} "
307+ assert batch .dtype == np .float64 , f"expected float64, got { batch .dtype } "
308+ assert batch .ndim == 2
309+ assert batch .shape == (batch_size , state_len ), f"unexpected shape { batch .shape } "
310+
311+
312+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
313+ @pytest .mark .skipif (not _cuda_available (), reason = "CUDA GPU required" )
314+ def test_as_numpy_amplitudes_are_unit_norm ():
315+ """Each row from as_numpy() should be a unit-norm state vector (amplitude encoding)."""
316+ num_qubits = 4
317+ batch_size = 16
318+
319+ loader = (
320+ QuantumDataLoader (device_id = 0 )
321+ .qubits (num_qubits )
322+ .batches (2 , size = batch_size )
323+ .source_synthetic ()
324+ .as_numpy ()
325+ )
326+ for batch in loader :
327+ arr = np .asarray (batch , dtype = np .float64 )
328+ norms = np .linalg .norm (arr , axis = 1 )
329+ np .testing .assert_allclose (norms , 1.0 , atol = 1e-5 )
330+
331+
332+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
333+ @pytest .mark .skipif (not _cuda_available (), reason = "CUDA GPU required" )
334+ def test_as_torch_yields_cuda_tensors ():
335+ """as_torch(device='cuda') yields torch tensors on CUDA."""
336+ try :
337+ import torch
338+ except ImportError :
339+ pytest .skip ("torch not installed" )
340+
341+ num_qubits = 4
342+ batch_size = 8
343+ state_len = 2 ** num_qubits
344+
345+ loader = (
346+ QuantumDataLoader (device_id = 0 )
347+ .qubits (num_qubits )
348+ .batches (2 , size = batch_size )
349+ .source_synthetic ()
350+ .as_torch (device = "cuda" )
351+ )
352+ for batch in loader :
353+ assert isinstance (batch , torch .Tensor )
354+ assert batch .is_cuda
355+ assert batch .shape == (batch_size , state_len )
356+
357+
358+ @pytest .mark .skipif (not _loader_available (), reason = "QuantumDataLoader not available" )
359+ @pytest .mark .skipif (not _cuda_available (), reason = "CUDA GPU required" )
360+ def test_as_numpy_from_source_array ():
361+ """as_numpy() works with source_array(), yielding correct shapes and dtype."""
362+ num_qubits = 3
363+ state_len = 2 ** num_qubits # 8
364+ n_samples = 12
365+ batch_size = 4
366+
367+ rng = np .random .default_rng (42 )
368+ X = rng .standard_normal ((n_samples , state_len ))
369+
370+ loader = (
371+ QuantumDataLoader (device_id = 0 )
372+ .qubits (num_qubits )
373+ .batches (1 , size = batch_size )
374+ .encoding ("amplitude" )
375+ .source_array (X )
376+ .as_numpy ()
377+ )
378+ batches = list (loader )
379+ assert len (batches ) == n_samples // batch_size
380+ for batch in batches :
381+ assert isinstance (batch , np .ndarray )
382+ assert batch .dtype == np .float64
383+ assert batch .shape [1 ] == state_len
0 commit comments