Skip to content

Commit ac8b5ba

Browse files
committed
feat: add direct encoding method for float32 tensors
1 parent 4438a9e commit ac8b5ba

File tree

3 files changed

+197
-81
lines changed

3 files changed

+197
-81
lines changed

qdp/qdp-python/src/lib.rs

Lines changed: 148 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,16 @@ fn validate_cuda_tensor_for_encoding(
359359
let dtype_str: String = dtype.str()?.extract()?;
360360
let dtype_str_lower = dtype_str.to_ascii_lowercase();
361361
match method.as_str() {
362-
"amplitude" | "angle" => {
362+
"amplitude" => {
363+
if !(dtype_str_lower.contains("float64") || dtype_str_lower.contains("float32")) {
364+
return Err(PyRuntimeError::new_err(format!(
365+
"CUDA tensor must have dtype float64 or float32 for amplitude encoding, got {}. \
366+
Use tensor.to(torch.float64) or tensor.to(torch.float32)",
367+
dtype_str
368+
)));
369+
}
370+
}
371+
"angle" => {
363372
if !dtype_str_lower.contains("float64") {
364373
return Err(PyRuntimeError::new_err(format!(
365374
"CUDA tensor must have dtype float64 for {} encoding, got {}. \
@@ -715,76 +724,7 @@ impl QdpEngine {
715724
if is_pytorch_tensor(data)? {
716725
// Check if it's a CUDA tensor - use zero-copy GPU encoding
717726
if is_cuda_tensor(data)? {
718-
// Validate CUDA tensor for direct GPU encoding
719-
validate_cuda_tensor_for_encoding(
720-
data,
721-
self.engine.device().ordinal(),
722-
encoding_method,
723-
)?;
724-
725-
// Extract GPU pointer directly from PyTorch tensor
726-
let tensor_info = extract_cuda_tensor_info(data)?;
727-
let stream_ptr = get_torch_cuda_stream_ptr(data)?;
728-
729-
let ndim: usize = data.call_method0("dim")?.extract()?;
730-
731-
match ndim {
732-
1 => {
733-
// 1D CUDA tensor: single sample encoding
734-
let input_len = tensor_info.shape[0] as usize;
735-
// SAFETY: tensor_info.data_ptr was obtained via PyTorch's data_ptr() from a
736-
// valid CUDA tensor. The tensor remains alive during this call
737-
// (held by Python's GIL), and we validated dtype/contiguity/device above.
738-
let ptr = unsafe {
739-
self.engine
740-
.encode_from_gpu_ptr_with_stream(
741-
tensor_info.data_ptr as *const std::ffi::c_void,
742-
input_len,
743-
num_qubits,
744-
encoding_method,
745-
stream_ptr,
746-
)
747-
.map_err(|e| {
748-
PyRuntimeError::new_err(format!("Encoding failed: {}", e))
749-
})?
750-
};
751-
return Ok(QuantumTensor {
752-
ptr,
753-
consumed: false,
754-
});
755-
}
756-
2 => {
757-
// 2D CUDA tensor: batch encoding
758-
let num_samples = tensor_info.shape[0] as usize;
759-
let sample_size = tensor_info.shape[1] as usize;
760-
// SAFETY: Same as above - pointer from validated PyTorch CUDA tensor
761-
let ptr = unsafe {
762-
self.engine
763-
.encode_batch_from_gpu_ptr_with_stream(
764-
tensor_info.data_ptr as *const std::ffi::c_void,
765-
num_samples,
766-
sample_size,
767-
num_qubits,
768-
encoding_method,
769-
stream_ptr,
770-
)
771-
.map_err(|e| {
772-
PyRuntimeError::new_err(format!("Encoding failed: {}", e))
773-
})?
774-
};
775-
return Ok(QuantumTensor {
776-
ptr,
777-
consumed: false,
778-
});
779-
}
780-
_ => {
781-
return Err(PyRuntimeError::new_err(format!(
782-
"Unsupported CUDA tensor shape: {}D. Expected 1D tensor for single \
783-
sample encoding or 2D tensor (batch_size, features) for batch encoding.",
784-
ndim
785-
)));
786-
}
787-
}
727+
return self._encode_from_cuda_tensor(data, num_qubits, encoding_method);
788728
}
789729
// CPU PyTorch tensor path
790730
return self.encode_from_pytorch(data, num_qubits, encoding_method);
@@ -1213,6 +1153,143 @@ impl QdpEngine {
12131153
.run_dual_stream_encode(&data_slice, num_qubits, encoding_method)
12141154
.map_err(|e| PyRuntimeError::new_err(format!("run_dual_stream_encode failed: {}", e)))
12151155
}
1156+
1157+
/// Encodes directly from a PyTorch CUDA tensor. Internal helper.
1158+
///
1159+
/// Dispatches to the core f32 GPU pointer API for 1D float32 amplitude encoding,
1160+
/// or to the float64/basis GPU pointer APIs for other dtypes and batch encoding.
1161+
///
1162+
/// Args:
1163+
/// data: PyTorch CUDA tensor
1164+
/// num_qubits: Number of qubits
1165+
/// encoding_method: Encoding strategy; the float32 fast path applies only to "amplitude",
/// other methods/dtypes fall through to the generic GPU-pointer path
1166+
fn _encode_from_cuda_tensor(
1167+
&self,
1168+
data: &Bound<'_, PyAny>,
1169+
num_qubits: usize,
1170+
encoding_method: &str,
1171+
) -> PyResult<QuantumTensor> {
1172+
// Validate CUDA tensor for direct GPU encoding (shape, contiguity, device, dtype)
1173+
validate_cuda_tensor_for_encoding(data, self.engine.device().ordinal(), encoding_method)?;
1174+
1175+
// Determine dtype for dispatch (float32 vs float64, etc.).
1176+
let dtype = data.getattr("dtype")?;
1177+
let dtype_str: String = dtype.str()?.extract()?;
1178+
let dtype_str_lower = dtype_str.to_ascii_lowercase();
1179+
let is_f32 = dtype_str_lower.contains("float32");
1180+
let method = encoding_method.to_ascii_lowercase();
1181+
1182+
// Current f32 CUDA path only supports amplitude encoding for 1D tensors.
1183+
let ndim: usize = data.call_method0("dim")?.extract()?;
1184+
1185+
if method.as_str() == "amplitude" && is_f32 {
1186+
match ndim {
1187+
1 => {
1188+
// 1D CUDA tensor, float32 amplitude encoding using core f32 GPU pointer API.
1189+
let input_len: usize = data.call_method0("numel")?.extract()?;
1190+
if input_len == 0 {
1191+
return Err(PyRuntimeError::new_err("CUDA tensor cannot be empty"));
1192+
}
1193+
1194+
let stream_ptr = get_torch_cuda_stream_ptr(data)?;
1195+
let data_ptr_u64: u64 = data.call_method0("data_ptr")?.extract()?;
1196+
if data_ptr_u64 == 0 {
1197+
return Err(PyRuntimeError::new_err(
1198+
"PyTorch returned a null data pointer for CUDA tensor",
1199+
));
1200+
}
1201+
let data_ptr = data_ptr_u64 as *const f32;
1202+
1203+
let ptr = unsafe {
1204+
self.engine
1205+
.encode_from_gpu_ptr_f32_with_stream(
1206+
data_ptr, input_len, num_qubits, stream_ptr,
1207+
)
1208+
.map_err(|e| {
1209+
PyRuntimeError::new_err(format!(
1210+
"Encoding failed (float32 amplitude): {}",
1211+
e
1212+
))
1213+
})?
1214+
};
1215+
1216+
Ok(QuantumTensor {
1217+
ptr,
1218+
consumed: false,
1219+
})
1220+
}
1221+
2 => Err(PyRuntimeError::new_err(
1222+
"CUDA float32 batch amplitude encoding is not yet supported. \
1223+
Use float64 (tensor.to(torch.float64)) or encode samples individually.",
1224+
)),
1225+
_ => Err(PyRuntimeError::new_err(format!(
1226+
"Unsupported CUDA tensor shape: {}D. Expected 1D tensor for single \
1227+
sample encoding or 2D tensor (batch_size, features) for batch encoding.",
1228+
ndim
1229+
))),
1230+
}
1231+
} else {
1232+
// Existing float64 (and basis/int64) CUDA path using direct GPU pointer.
1233+
let tensor_info = extract_cuda_tensor_info(data)?;
1234+
let stream_ptr = get_torch_cuda_stream_ptr(data)?;
1235+
1236+
match ndim {
1237+
1 => {
1238+
// 1D CUDA tensor: single sample encoding
1239+
let input_len = tensor_info.shape[0] as usize;
1240+
// SAFETY: tensor_info.data_ptr was obtained via PyTorch's data_ptr() from a
1241+
// valid CUDA tensor. The tensor remains alive during this call
1242+
// (held by Python's GIL), and we validated dtype/contiguity/device above.
1243+
let ptr = unsafe {
1244+
self.engine
1245+
.encode_from_gpu_ptr_with_stream(
1246+
tensor_info.data_ptr as *const std::ffi::c_void,
1247+
input_len,
1248+
num_qubits,
1249+
encoding_method,
1250+
stream_ptr,
1251+
)
1252+
.map_err(|e| {
1253+
PyRuntimeError::new_err(format!("Encoding failed: {}", e))
1254+
})?
1255+
};
1256+
Ok(QuantumTensor {
1257+
ptr,
1258+
consumed: false,
1259+
})
1260+
}
1261+
2 => {
1262+
// 2D CUDA tensor: batch encoding
1263+
let num_samples = tensor_info.shape[0] as usize;
1264+
let sample_size = tensor_info.shape[1] as usize;
1265+
// SAFETY: Same as above - pointer from validated PyTorch CUDA tensor
1266+
let ptr = unsafe {
1267+
self.engine
1268+
.encode_batch_from_gpu_ptr_with_stream(
1269+
tensor_info.data_ptr as *const std::ffi::c_void,
1270+
num_samples,
1271+
sample_size,
1272+
num_qubits,
1273+
encoding_method,
1274+
stream_ptr,
1275+
)
1276+
.map_err(|e| {
1277+
PyRuntimeError::new_err(format!("Encoding failed: {}", e))
1278+
})?
1279+
};
1280+
Ok(QuantumTensor {
1281+
ptr,
1282+
consumed: false,
1283+
})
1284+
}
1285+
_ => Err(PyRuntimeError::new_err(format!(
1286+
"Unsupported CUDA tensor shape: {}D. Expected 1D tensor for single \
1287+
sample encoding or 2D tensor (batch_size, features) for batch encoding.",
1288+
ndim
1289+
))),
1290+
}
1291+
}
1292+
}
12161293
}
12171294

12181295
/// Runs the full throughput pipeline in Rust with GIL released. Returns (duration_sec, vectors_per_sec, latency_ms_per_vector).

qdp/qdp-python/tests/test_dlpack_validation.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,30 @@ def _engine():
3838

3939

4040
@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
41-
def test_dtype_validation_float32_rejected():
42-
"""DLPack tensor must be float64; float32 CUDA tensor should fail with clear message."""
41+
def test_cuda_float32_amplitude_supported():
42+
"""1D float32 CUDA tensor should be supported for amplitude encoding via GPU pointer f32 path."""
4343
engine = _engine()
4444
# 1D float32 CUDA tensor (contiguous)
4545
t = torch.randn(4, dtype=torch.float32, device="cuda")
46-
with pytest.raises(RuntimeError) as exc_info:
46+
result = engine.encode(t, num_qubits=2, encoding_method="amplitude")
47+
assert result is not None
48+
49+
# Verify DLPack round-trip works and tensor is on CUDA
50+
qt = torch.from_dlpack(result)
51+
assert qt.is_cuda
52+
# With default engine precision=float32, complex64 is expected
53+
assert qt.dtype in (torch.complex64, torch.complex128)
54+
55+
56+
@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")
57+
def test_cuda_float32_amplitude_2d_unsupported():
58+
"""2D float32 CUDA tensor with amplitude encoding should raise a clear error."""
59+
engine = _engine()
60+
t = torch.randn(2, 4, dtype=torch.float32, device="cuda")
61+
with pytest.raises(
62+
RuntimeError, match="float32 batch amplitude encoding is not yet supported"
63+
):
4764
engine.encode(t, num_qubits=2, encoding_method="amplitude")
48-
msg = str(exc_info.value).lower()
49-
assert "float64" in msg
50-
assert "code=" in msg or "bits=" in msg or "lanes=" in msg
5165

5266

5367
@pytest.mark.skipif(not _cuda_available(), reason="CUDA not available")

testing/qdp/test_bindings.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_encode_cuda_tensor(data_shape, expected_shape, expected_batch_size):
315315
@requires_qdp
316316
@pytest.mark.gpu
317317
def test_encode_cuda_tensor_wrong_dtype():
318-
"""Test error when CUDA tensor has wrong dtype (non-float64)."""
318+
"""Test error when CUDA tensor has wrong dtype for amplitude (e.g. float16)."""
319319
pytest.importorskip("torch")
320320
from _qdp import QdpEngine
321321

@@ -324,9 +324,9 @@ def test_encode_cuda_tensor_wrong_dtype():
324324

325325
engine = QdpEngine(0)
326326

327-
# Create CUDA tensor with float32 dtype (wrong)
328-
data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda:0")
329-
with pytest.raises(RuntimeError, match="CUDA tensor must have dtype float64"):
327+
# Amplitude encoding accepts float64 or float32 only; float16 is invalid
328+
data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float16, device="cuda:0")
329+
with pytest.raises(RuntimeError, match="float64 or float32"):
330330
engine.encode(data, 2, "amplitude")
331331

332332

@@ -537,6 +537,31 @@ def test_encode_cuda_tensor_output_dtype(precision, expected_dtype):
537537
)
538538

539539

540+
@requires_qdp
541+
@pytest.mark.gpu
542+
@pytest.mark.parametrize(
543+
"precision,expected_dtype",
544+
[
545+
("float32", torch.complex64),
546+
("float64", torch.complex128),
547+
],
548+
)
549+
def test_encode_cuda_tensor_float32_input_output_dtype(precision, expected_dtype):
550+
"""Test that 1D float32 CUDA amplitude encoding respects engine precision (f32 path)."""
551+
pytest.importorskip("torch")
552+
from _qdp import QdpEngine
553+
554+
if not torch.cuda.is_available():
555+
pytest.skip("GPU required for QdpEngine")
556+
557+
engine = QdpEngine(0, precision=precision)
558+
data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda:0")
559+
result = torch.from_dlpack(engine.encode(data, 2, "amplitude"))
560+
assert result.dtype == expected_dtype, (
561+
f"Expected {expected_dtype}, got {result.dtype}"
562+
)
563+
564+
540565
@requires_qdp
541566
@pytest.mark.gpu
542567
def test_basis_encode_basic():

0 commit comments

Comments
 (0)