Skip to content

Commit e956642

Browse files
committed
fix: more review
1 parent 0e4de9f commit e956642

3 files changed

Lines changed: 59 additions & 70 deletions

File tree

dpctl/_sycl_queue.pyx

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -584,20 +584,14 @@ cdef DPCTLSyclEventRef _copy_impl(
584584
SyclQueue q,
585585
object dst,
586586
object src,
587-
size_t byte_count,
587+
size_t count,
588588
DPCTLSyclEventRef *dep_events,
589589
size_t dep_events_count,
590-
str dtype=None
590+
str dtype="u1"
591591
) except *:
592-
cdef size_t element_size = 0
593-
594-
if dtype is not None:
595-
element_size = _get_dtype_size(dtype)
596-
if byte_count % element_size != 0:
597-
raise ValueError(
598-
f"byte_count ({byte_count}) must be a multiple of dtype "
599-
f"element size ({element_size} bytes for '{dtype}')"
600-
)
592+
# ``count`` is in elements of ``dtype`` (default "u1" => bytes).
593+
cdef size_t element_size = _get_dtype_size(dtype)
594+
cdef size_t byte_count = count * element_size
601595

602596
return _copy_memcpy_impl(
603597
q, dst, src, byte_count, dep_events, dep_events_count,
@@ -1495,11 +1489,14 @@ cdef class SyclQueue(_SyclQueue):
14951489

14961490
return SyclEvent._create(ERef)
14971491

1498-
cpdef copy(self, dest, src, size_t count, str dtype=None):
1499-
"""Copy ``count`` bytes from ``src`` to ``dest`` and wait.
1492+
cpdef copy(self, dest, src, size_t count, str dtype="u1"):
1493+
"""Copy ``count`` elements of type ``dtype`` from ``src`` to
1494+
``dest`` and wait.
15001495
1501-
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1502-
byte-sized elements (or typed elements if dtype is specified).
1496+
Internally, this dispatches ``sycl::queue::copy``. The number of
1497+
bytes transferred is ``count`` multiplied by the size of ``dtype``.
1498+
The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1499+
a byte-wise copy.
15031500
15041501
This is a synchronizing variant corresponding to
15051502
:meth:`dpctl.SyclQueue.copy_async`.
@@ -1512,11 +1509,12 @@ cdef class SyclQueue(_SyclQueue):
15121509
Source USM object or Python object supporting buffer
15131510
protocol.
15141511
count (int):
1515-
Number of bytes to copy.
1512+
Number of elements to copy.
15161513
dtype (str, optional):
1517-
Data type string (e.g., 'i4', 'f8') for typed copy
1518-
validation. If provided, validates that count is a
1519-
multiple of the element size.
1514+
Data type string of the elements to copy. Determines the
1515+
element size used to convert ``count`` into a byte count.
1516+
Defaults to ``"u1"`` (one byte per element).
1517+
Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15201518
"""
15211519
cdef DPCTLSyclEventRef ERef = NULL
15221520

@@ -1530,12 +1528,15 @@ cdef class SyclQueue(_SyclQueue):
15301528
DPCTLEvent_Delete(ERef)
15311529

15321530
cpdef SyclEvent copy_async(
1533-
self, dest, src, size_t count, list dEvents=None, str dtype=None
1531+
self, dest, src, size_t count, list dEvents=None, str dtype="u1"
15341532
):
1535-
"""Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
1533+
"""Copy ``count`` elements of type ``dtype`` from ``src`` to
1534+
``dest`` asynchronously.
15361535
1537-
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1538-
byte-sized elements (or typed elements if dtype is specified).
1536+
Internally, this dispatches ``sycl::queue::copy``. The number of
1537+
bytes transferred is ``count`` multiplied by the size of ``dtype``.
1538+
The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1539+
a byte-wise copy.
15391540
15401541
Args:
15411542
dest:
@@ -1545,13 +1546,13 @@ cdef class SyclQueue(_SyclQueue):
15451546
Source USM object or Python object supporting buffer
15461547
protocol.
15471548
count (int):
1548-
Number of bytes to copy.
1549+
Number of elements to copy.
15491550
dEvents (List[dpctl.SyclEvent], optional):
15501551
Events that this copy depends on.
15511552
dtype (str, optional):
1552-
Data type string (e.g., 'i4', 'f8') for typed copy
1553-
validation. If provided, validates that count is a
1554-
multiple of the element size.
1553+
Data type string of the elements to copy. Determines the
1554+
element size used to convert ``count`` into a byte count.
1555+
Defaults to ``"u1"`` (one byte per element).
15551556
Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15561557
15571558
Returns:

dpctl/tests/test_sycl_queue_copy.py

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def test_copy_with_dtype_valid(dtype, element_size, usm_type):
122122
host_buf = bytearray(i % 256 for i in range(nbytes))
123123
src.copy_from_host(host_buf)
124124

125-
# Copy device to device with dtype validation
126-
q.copy(dst, src, nbytes, dtype=dtype)
125+
# Copy device to device, count given in elements of dtype
126+
q.copy(dst, src, num_elements, dtype=dtype)
127127

128128
# Verify via host buffer
129129
result_buf = bytearray(nbytes)
@@ -135,8 +135,8 @@ def test_copy_with_dtype_valid(dtype, element_size, usm_type):
135135
for i in range(nbytes):
136136
src_mv[i] = i % 256
137137

138-
# Copy with dtype validation
139-
q.copy(dst, src, nbytes, dtype=dtype)
138+
# Copy, count given in elements of dtype
139+
q.copy(dst, src, num_elements, dtype=dtype)
140140

141141
# Verify
142142
dst_mv = memoryview(dst)
@@ -163,8 +163,8 @@ def test_copy_async_with_dtype_valid():
163163
for i in range(nbytes):
164164
src_mv[i] = i % 256
165165

166-
# Async copy with dtype
167-
e = q.copy_async(dst, src, nbytes, dtype=dtype)
166+
# Async copy with dtype, count given in elements
167+
e = q.copy_async(dst, src, num_elements, dtype=dtype)
168168
e.wait()
169169

170170
# Verify
@@ -173,47 +173,38 @@ def test_copy_async_with_dtype_valid():
173173

174174

175175
@pytest.mark.parametrize(
176-
"dtype,element_size,bad_count",
176+
"dtype,element_size",
177177
[
178-
("i4", 4, 7), # 7 is not divisible by 4
179-
("f8", 8, 13),
180-
("i2", 2, 5),
181-
("u8", 8, 3),
178+
("i2", 2),
179+
("i4", 4),
180+
("f8", 8),
181+
("u8", 8),
182182
],
183183
)
184-
def test_copy_with_dtype_invalid_count(dtype, element_size, bad_count):
185-
"""Test copy raises ValueError when count isn't a dtype multiple."""
184+
def test_copy_count_is_in_elements(dtype, element_size):
185+
"""``count`` is interpreted as a number of elements of ``dtype``."""
186186
try:
187187
q = dpctl.SyclQueue()
188188
except dpctl.SyclQueueCreationError:
189189
pytest.skip("Default constructor for SyclQueue failed")
190190

191-
nbytes = 64
191+
num_elements = 4
192+
nbytes = num_elements * element_size
193+
192194
src = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
193195
dst = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
194196

195-
with pytest.raises(ValueError) as cm:
196-
q.copy(dst, src, bad_count, dtype=dtype)
197-
assert "multiple" in str(cm.value).lower()
198-
assert str(element_size) in str(cm.value)
199-
200-
201-
def test_copy_async_with_dtype_invalid_count():
202-
"""Test that copy_async raises ValueError for invalid count."""
203-
try:
204-
q = dpctl.SyclQueue()
205-
except dpctl.SyclQueueCreationError:
206-
pytest.skip("Default constructor for SyclQueue failed")
197+
src_mv = memoryview(src)
198+
for i in range(nbytes):
199+
src_mv[i] = i % 256
200+
dst_mv = memoryview(dst)
207201

208-
dtype = "i4"
209-
bad_count = 7 # Not divisible by 4
210-
nbytes = 64
211-
src = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
212-
dst = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
202+
# Copying half the elements transfers exactly half the bytes.
203+
q.copy(dst, src, num_elements // 2, dtype=dtype)
213204

214-
with pytest.raises(ValueError) as cm:
215-
q.copy_async(dst, src, bad_count, dtype=dtype)
216-
assert "multiple" in str(cm.value).lower()
205+
half_bytes = (num_elements // 2) * element_size
206+
assert dst_mv[:half_bytes].tobytes() == src_mv[:half_bytes].tobytes()
207+
assert dst_mv[half_bytes:nbytes].tobytes() == bytes(nbytes - half_bytes)
217208

218209

219210
def test_copy_with_invalid_dtype():
@@ -247,13 +238,11 @@ def test_copy_with_dtype_host_buffers():
247238

248239
dtype = "f4"
249240
num_elements = 20
250-
element_size = 4
251-
nbytes = num_elements * element_size
252241

253242
src = np.arange(num_elements, dtype=np.float32)
254243
dst = np.zeros(num_elements, dtype=np.float32)
255244

256-
q.copy(dst, src, nbytes, dtype=dtype)
245+
q.copy(dst, src, num_elements, dtype=dtype)
257246

258247
assert np.array_equal(dst, src)
259248

@@ -267,18 +256,17 @@ def test_copy_with_dtype_mixed_sources():
267256

268257
dtype = "i8"
269258
num_elements = 10
270-
element_size = 8
271-
nbytes = num_elements * element_size
259+
nbytes = num_elements * 8
272260

273261
# Host to USM
274262
src_host = np.arange(num_elements, dtype=np.int64)
275263
dst_usm = dpctl.memory.MemoryUSMShared(nbytes, queue=q)
276264

277-
q.copy(dst_usm, src_host, nbytes, dtype=dtype)
265+
q.copy(dst_usm, src_host, num_elements, dtype=dtype)
278266

279267
# USM to host
280268
dst_host = np.zeros(num_elements, dtype=np.int64)
281-
q.copy(dst_host, dst_usm, nbytes, dtype=dtype)
269+
q.copy(dst_host, dst_usm, num_elements, dtype=dtype)
282270

283271
assert np.array_equal(dst_host, src_host)
284272

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ DPCTLQueue_CopyData(__dpctl_keep const DPCTLSyclQueueRef QRef,
704704
if (Q) {
705705
sycl::event ev;
706706
try {
707-
// Bind queue::copy with uint8_t so Count is interpreted as bytes.
707+
// Copy uint8_t elements (1 byte each), so Count is a byte count.
708708
ev = Q->copy(static_cast<const std::uint8_t *>(Src),
709709
static_cast<std::uint8_t *>(Dest), Count);
710710
} catch (std::exception const &e) {
@@ -741,7 +741,7 @@ DPCTLQueue_CopyDataWithEvents(__dpctl_keep const DPCTLSyclQueueRef QRef,
741741
}
742742
}
743743

744-
// Bind queue::copy with uint8_t so Count is interpreted as bytes.
744+
// Copy uint8_t elements (1 byte each), so Count is a byte count.
745745
auto ev =
746746
Q->copy(static_cast<const std::uint8_t *>(Src),
747747
static_cast<std::uint8_t *>(Dest), Count, dep_events);

0 commit comments

Comments
 (0)