Skip to content

Commit d7366a1

Browse files
committed
add unit test
Signed-off-by: 400Ping <jiekaichang@apache.org>
1 parent 66bc828 commit d7366a1

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

qdp/qdp-core/tests/dlpack.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818

1919
#[cfg(test)]
2020
mod dlpack_tests {
21+
use std::ffi::c_void;
22+
2123
use cudarc::driver::CudaDevice;
24+
use qdp_core::dlpack::{synchronize_stream, CUDA_STREAM_LEGACY};
2225
use qdp_core::gpu::memory::GpuStateVector;
2326

2427
#[test]
@@ -82,4 +85,29 @@ mod dlpack_tests {
8285
}
8386
}
8487
}
88+
89+
/// synchronize_stream(null) is a no-op and returns Ok(()) on all platforms.
90+
#[test]
91+
fn test_synchronize_stream_null() {
92+
unsafe {
93+
let result = synchronize_stream(std::ptr::null_mut::<c_void>());
94+
assert!(
95+
result.is_ok(),
96+
"synchronize_stream(null) should return Ok(())"
97+
);
98+
}
99+
}
100+
101+
/// synchronize_stream(CUDA_STREAM_LEGACY) synchronizes the legacy default CUDA stream (requires Linux with CUDA available).
102+
#[test]
103+
#[cfg(target_os = "linux")]
104+
fn test_synchronize_stream_legacy() {
105+
unsafe {
106+
let result = synchronize_stream(CUDA_STREAM_LEGACY);
107+
assert!(
108+
result.is_ok(),
109+
"synchronize_stream(CUDA_STREAM_LEGACY) should succeed on Linux with CUDA"
110+
);
111+
}
112+
}
85113
}

testing/qdp/test_bindings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,25 @@ def test_dlpack_single_use():
125125
qtensor2.__dlpack__()
126126

127127

128+
@requires_qdp
129+
@pytest.mark.gpu
130+
@pytest.mark.parametrize("stream", [1, 2], ids=["stream_legacy", "stream_per_thread"])
131+
def test_dlpack_with_stream(stream):
132+
"""Test __dlpack__(stream=...) syncs CUDA stream before returning capsule (DLPack 0.8+)."""
133+
import torch
134+
from _qdp import QdpEngine
135+
136+
engine = QdpEngine(0)
137+
data = [1.0, 2.0, 3.0, 4.0]
138+
qtensor = engine.encode(data, 2, "amplitude")
139+
140+
# stream=1 (legacy default stream) or stream=2 (per-thread default stream) should sync and return a capsule
141+
capsule = qtensor.__dlpack__(stream=stream)
142+
torch_tensor = torch.from_dlpack(capsule)
143+
assert torch_tensor.is_cuda
144+
assert torch_tensor.shape == (1, 4)
145+
146+
128147
@requires_qdp
129148
@pytest.mark.gpu
130149
def test_pytorch_integration():

0 commit comments

Comments
 (0)