Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions qdp/qdp-core/examples/dataloader_throughput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::thread;
use std::time::{Duration, Instant};

use qdp_core::QdpEngine;
use qdp_core::dlpack::free_dlpack_tensor;

const BATCH_SIZE: usize = 64;
const VECTOR_LEN: usize = 1024; // 2^10
Expand Down Expand Up @@ -99,12 +100,15 @@ fn main() {
debug_assert_eq!(batch.len() % VECTOR_LEN, 0);
let num_samples = batch.len() / VECTOR_LEN;
match engine.encode_batch(&batch, num_samples, VECTOR_LEN, NUM_QUBITS, "amplitude") {
Ok(ptr) => unsafe {
let managed = &mut *ptr;
if let Some(deleter) = managed.deleter.take() {
deleter(ptr);
Ok(ptr) => {
if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
eprintln!(
"Failed to free DLPack tensor for batch {} (processed {} vectors): {:?}",
batch_idx, total_vectors, e
);
return;
}
},
}
Err(e) => {
eprintln!(
"Encode batch failed on batch {} (processed {} vectors): {:?}",
Expand Down
12 changes: 5 additions & 7 deletions qdp/qdp-core/examples/nvtx_profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Run: cargo run -p qdp-core --example nvtx_profile --features observability --release

use qdp_core::QdpEngine;
use qdp_core::dlpack::free_dlpack_tensor;

fn main() {
println!("=== NVTX Profiling Example ===");
Expand Down Expand Up @@ -61,13 +62,10 @@ fn main() {
println!("✓ Encoding succeeded");
println!("✓ DLPack pointer: {:p}", ptr);

// Clean up
unsafe {
let managed = &mut *ptr;
if let Some(deleter) = managed.deleter.take() {
deleter(ptr);
println!("✓ Memory freed");
}
// Clean up using shared helper with safety checks
match unsafe { free_dlpack_tensor(ptr) } {
Ok(()) => println!("✓ Memory freed"),
Err(e) => eprintln!("✗ Failed to free DLPack tensor: {:?}", e),
}
}
Err(e) => {
Expand Down
11 changes: 6 additions & 5 deletions qdp/qdp-core/examples/observability_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
// Run: cargo run -p qdp-core --example observability_test --release

use qdp_core::QdpEngine;
use qdp_core::dlpack::free_dlpack_tensor;
use std::env;

fn main() {
Expand Down Expand Up @@ -92,12 +93,12 @@ fn main() {
for i in 0..NUM_SAMPLES {
let sample = &test_data[i * VECTOR_LEN..(i + 1) * VECTOR_LEN];
match engine.encode(sample, NUM_QUBITS, "amplitude") {
Ok(ptr) => unsafe {
let managed = &mut *ptr;
if let Some(deleter) = managed.deleter.take() {
deleter(ptr);
Ok(ptr) => {
if let Err(e) = unsafe { free_dlpack_tensor(ptr) } {
eprintln!("✗ Failed to free DLPack tensor for sample {}: {:?}", i, e);
return;
}
},
}
Err(e) => {
eprintln!("✗ Encoding failed for sample {}: {:?}", i, e);
return;
Expand Down
43 changes: 41 additions & 2 deletions qdp/qdp-core/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

// DLPack protocol for zero-copy GPU memory sharing with PyTorch

use crate::error::Result;
#[cfg(target_os = "linux")]
use crate::error::{MahoutError, cuda_error_to_string};
use crate::error::cuda_error_to_string;
use crate::error::{MahoutError, Result};
use crate::gpu::memory::{BufferStorage, GpuStateVector, Precision};
use std::os::raw::{c_int, c_void};
use std::sync::Arc;
Expand Down Expand Up @@ -205,6 +205,45 @@ pub unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
let _ = Box::from_raw(managed);
}

/// Free a `DLManagedTensor` returned from the encoding APIs by invoking its deleter.
///
/// This helper centralizes the raw-pointer dereference and deleter invocation
/// that callers would otherwise repeat. It rejects a null pointer and a tensor
/// whose `deleter` slot is empty instead of dereferencing or calling blindly.
///
/// The deleter is moved out of the tensor (leaving `None`) before it is called,
/// so a second call on a tensor whose allocation is still alive returns an
/// error. If the deleter releases the allocation (as `dlpack_deleter` does),
/// passing the same pointer again is a use-after-free that cannot be detected
/// here; avoiding it remains the caller's responsibility.
///
/// # Safety
/// The caller must ensure:
/// - `ptr` is either null or a valid `DLManagedTensor` pointer returned from
///   `QdpEngine::encode()` or similar methods
/// - The pointer has not been freed before (either by calling this function
///   or by PyTorch's DLPack deleter)
/// - No other reference to the tensor is alive for the duration of this call
/// - The pointer is not used after this call
///
/// # Errors
/// Returns `MahoutError::InvalidInput` if:
/// - The pointer is null
/// - The deleter is missing or has already been taken
pub unsafe fn free_dlpack_tensor(ptr: *mut DLManagedTensor) -> Result<()> {
    if ptr.is_null() {
        return Err(MahoutError::InvalidInput(
            "DLPack pointer is null (nothing to free)".into(),
        ));
    }

    // SAFETY: `ptr` is non-null (checked above) and the caller guarantees it
    // points to a live, exclusively accessible `DLManagedTensor`.
    let managed = unsafe { &mut *ptr };

    // Move the deleter out so the slot is left empty before it runs.
    let deleter = managed.deleter.take().ok_or_else(|| {
        MahoutError::InvalidInput("DLPack deleter missing or already called".into())
    })?;

    // SAFETY: the deleter was installed by the tensor's producer and receives
    // the tensor it belongs to; `managed` is not used past this point.
    unsafe { deleter(ptr) };
    Ok(())
}

impl GpuStateVector {
/// Convert to DLPack format for PyTorch
///
Expand Down
132 changes: 131 additions & 1 deletion qdp/qdp-core/tests/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ mod dlpack_tests {
use std::ffi::c_void;

use cudarc::driver::CudaDevice;
use qdp_core::MahoutError;
use qdp_core::Precision;
use qdp_core::dlpack::{CUDA_STREAM_LEGACY, synchronize_stream};
use qdp_core::dlpack::{
CUDA_STREAM_LEGACY, DL_FLOAT, DLDataType, DLDevice, DLDeviceType, DLManagedTensor,
DLTensor, free_dlpack_tensor, synchronize_stream,
};
use qdp_core::gpu::memory::GpuStateVector;

#[test]
Expand Down Expand Up @@ -144,4 +148,130 @@ mod dlpack_tests {
);
}
}

/// Passing a null pointer to free_dlpack_tensor must produce an InvalidInput
/// error whose message mentions "null", rather than panicking.
#[test]
fn test_free_dlpack_tensor_null_ptr() {
    // SAFETY: a null pointer is explicitly allowed by free_dlpack_tensor's contract.
    let outcome = unsafe { free_dlpack_tensor(std::ptr::null_mut()) };

    // Anything other than InvalidInput is a test failure.
    let msg = match outcome {
        Err(MahoutError::InvalidInput(msg)) => msg,
        other => panic!(
            "Expected InvalidInput error for null pointer, got: {:?}",
            other
        ),
    };

    assert!(
        msg.to_lowercase().contains("null"),
        "Expected null-pointer error message, got: {}",
        msg
    );
}

/// A tensor whose deleter slot is `None` must be rejected with an InvalidInput
/// error whose message mentions "deleter".
#[test]
fn test_free_dlpack_tensor_missing_deleter() {
    // Heap-allocate a minimal, structurally valid tensor with no deleter installed.
    let ptr = Box::into_raw(Box::new(DLManagedTensor {
        dl_tensor: DLTensor {
            data: std::ptr::null_mut(),
            device: DLDevice {
                device_type: DLDeviceType::kDLCPU,
                device_id: 0,
            },
            ndim: 0,
            dtype: DLDataType {
                code: DL_FLOAT,
                bits: 64,
                lanes: 1,
            },
            shape: std::ptr::null_mut(),
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        },
        manager_ctx: std::ptr::null_mut(),
        deleter: None,
    }));

    // SAFETY: `ptr` was just produced by Box::into_raw and has not been freed.
    let outcome = unsafe { free_dlpack_tensor(ptr) };

    match outcome {
        Err(MahoutError::InvalidInput(msg)) => assert!(
            msg.to_lowercase().contains("deleter"),
            "Expected missing-deleter error message, got: {}",
            msg
        ),
        other => panic!(
            "Expected InvalidInput error for missing deleter, got: {:?}",
            other
        ),
    }

    // With no deleter present the helper leaves the allocation untouched,
    // so take ownership back here to keep the test leak-free.
    // SAFETY: `ptr` still owns the Box allocation created above.
    drop(unsafe { Box::from_raw(ptr) });
}

/// free_dlpack_tensor should call the deleter exactly once and return Ok(()).
#[test]
fn test_free_dlpack_tensor_calls_deleter() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // An atomic counter replaces `static mut`: it needs no unsafe access, stays
    // sound when the harness runs tests on multiple threads, and lets the test
    // verify "exactly once" rather than "at least once".
    static DELETER_CALLS: AtomicUsize = AtomicUsize::new(0);

    // Records the invocation only; it deliberately does not free the tensor.
    unsafe extern "C" fn test_deleter(_ptr: *mut DLManagedTensor) {
        DELETER_CALLS.fetch_add(1, Ordering::SeqCst);
    }

    // Minimal, but structurally valid, DLTensor for constructing DLManagedTensor.
    let dummy_tensor = DLTensor {
        data: std::ptr::null_mut(),
        device: DLDevice {
            device_type: DLDeviceType::kDLCPU,
            device_id: 0,
        },
        ndim: 0,
        dtype: DLDataType {
            code: DL_FLOAT,
            bits: 64,
            lanes: 1,
        },
        shape: std::ptr::null_mut(),
        strides: std::ptr::null_mut(),
        byte_offset: 0,
    };

    let managed = DLManagedTensor {
        dl_tensor: dummy_tensor,
        manager_ctx: std::ptr::null_mut(),
        deleter: Some(test_deleter),
    };

    let ptr = Box::into_raw(Box::new(managed));

    // SAFETY: `ptr` was just produced by Box::into_raw and has not been freed.
    let result = unsafe { free_dlpack_tensor(ptr) };
    assert!(
        result.is_ok(),
        "free_dlpack_tensor should succeed for valid pointer: {:?}",
        result
    );
    assert_eq!(
        DELETER_CALLS.load(Ordering::SeqCst),
        1,
        "free_dlpack_tensor should invoke the DLPack deleter exactly once"
    );

    // Our custom deleter doesn't free the allocation; reclaim it here.
    // SAFETY: `ptr` still owns the Box allocation created above.
    drop(unsafe { Box::from_raw(ptr) });
}
}