@@ -122,8 +122,8 @@ def test_copy_with_dtype_valid(dtype, element_size, usm_type):
122122 host_buf = bytearray (i % 256 for i in range (nbytes ))
123123 src .copy_from_host (host_buf )
124124
125- # Copy device to device with dtype validation
126- q .copy (dst , src , nbytes , dtype = dtype )
125+ # Copy device to device, count given in elements of dtype
126+ q .copy (dst , src , num_elements , dtype = dtype )
127127
128128 # Verify via host buffer
129129 result_buf = bytearray (nbytes )
@@ -135,8 +135,8 @@ def test_copy_with_dtype_valid(dtype, element_size, usm_type):
135135 for i in range (nbytes ):
136136 src_mv [i ] = i % 256
137137
138- # Copy with dtype validation
139- q .copy (dst , src , nbytes , dtype = dtype )
138+ # Copy, count given in elements of dtype
139+ q .copy (dst , src , num_elements , dtype = dtype )
140140
141141 # Verify
142142 dst_mv = memoryview (dst )
@@ -163,8 +163,8 @@ def test_copy_async_with_dtype_valid():
163163 for i in range (nbytes ):
164164 src_mv [i ] = i % 256
165165
166- # Async copy with dtype
167- e = q .copy_async (dst , src , nbytes , dtype = dtype )
166+ # Async copy with dtype, count given in elements
167+ e = q .copy_async (dst , src , num_elements , dtype = dtype )
168168 e .wait ()
169169
170170 # Verify
@@ -173,47 +173,38 @@ def test_copy_async_with_dtype_valid():
173173
174174
175175@pytest .mark .parametrize (
176- "dtype,element_size,bad_count " ,
176+ "dtype,element_size" ,
177177 [
178- ("i4 " , 4 , 7 ), # 7 is not divisible by 4
179- ("f8 " , 8 , 13 ),
180- ("i2 " , 2 , 5 ),
181- ("u8" , 8 , 3 ),
178+ ("i2 " , 2 ),
179+ ("i4 " , 4 ),
180+ ("f8 " , 8 ),
181+ ("u8" , 8 ),
182182 ],
183183)
184- def test_copy_with_dtype_invalid_count (dtype , element_size , bad_count ):
185- """Test copy raises ValueError when count isn't a dtype multiple ."""
184+ def test_copy_count_is_in_elements (dtype , element_size ):
185+ """``count`` is interpreted as a number of elements of ``dtype`` ."""
186186 try :
187187 q = dpctl .SyclQueue ()
188188 except dpctl .SyclQueueCreationError :
189189 pytest .skip ("Default constructor for SyclQueue failed" )
190190
191- nbytes = 64
191+ num_elements = 4
192+ nbytes = num_elements * element_size
193+
192194 src = dpctl .memory .MemoryUSMShared (nbytes , queue = q )
193195 dst = dpctl .memory .MemoryUSMShared (nbytes , queue = q )
194196
195- with pytest .raises (ValueError ) as cm :
196- q .copy (dst , src , bad_count , dtype = dtype )
197- assert "multiple" in str (cm .value ).lower ()
198- assert str (element_size ) in str (cm .value )
199-
200-
201- def test_copy_async_with_dtype_invalid_count ():
202- """Test that copy_async raises ValueError for invalid count."""
203- try :
204- q = dpctl .SyclQueue ()
205- except dpctl .SyclQueueCreationError :
206- pytest .skip ("Default constructor for SyclQueue failed" )
197+ src_mv = memoryview (src )
198+ for i in range (nbytes ):
199+ src_mv [i ] = i % 256
200+ dst_mv = memoryview (dst )
207201
208- dtype = "i4"
209- bad_count = 7 # Not divisible by 4
210- nbytes = 64
211- src = dpctl .memory .MemoryUSMShared (nbytes , queue = q )
212- dst = dpctl .memory .MemoryUSMShared (nbytes , queue = q )
202+ # Copying half the elements transfers exactly half the bytes.
203+ q .copy (dst , src , num_elements // 2 , dtype = dtype )
213204
214- with pytest . raises ( ValueError ) as cm :
215- q . copy_async ( dst , src , bad_count , dtype = dtype )
216- assert "multiple" in str ( cm . value ). lower ( )
205+ half_bytes = ( num_elements // 2 ) * element_size
206+ assert dst_mv [: half_bytes ]. tobytes () == src_mv [: half_bytes ]. tobytes ( )
207+ assert dst_mv [ half_bytes : nbytes ]. tobytes () == bytes ( nbytes - half_bytes )
217208
218209
219210def test_copy_with_invalid_dtype ():
@@ -247,13 +238,11 @@ def test_copy_with_dtype_host_buffers():
247238
248239 dtype = "f4"
249240 num_elements = 20
250- element_size = 4
251- nbytes = num_elements * element_size
252241
253242 src = np .arange (num_elements , dtype = np .float32 )
254243 dst = np .zeros (num_elements , dtype = np .float32 )
255244
256- q .copy (dst , src , nbytes , dtype = dtype )
245+ q .copy (dst , src , num_elements , dtype = dtype )
257246
258247 assert np .array_equal (dst , src )
259248
@@ -267,18 +256,17 @@ def test_copy_with_dtype_mixed_sources():
267256
268257 dtype = "i8"
269258 num_elements = 10
270- element_size = 8
271- nbytes = num_elements * element_size
259+ nbytes = num_elements * 8
272260
273261 # Host to USM
274262 src_host = np .arange (num_elements , dtype = np .int64 )
275263 dst_usm = dpctl .memory .MemoryUSMShared (nbytes , queue = q )
276264
277- q .copy (dst_usm , src_host , nbytes , dtype = dtype )
265+ q .copy (dst_usm , src_host , num_elements , dtype = dtype )
278266
279267 # USM to host
280268 dst_host = np .zeros (num_elements , dtype = np .int64 )
281- q .copy (dst_host , dst_usm , nbytes , dtype = dtype )
269+ q .copy (dst_host , dst_usm , num_elements , dtype = dtype )
282270
283271 assert np .array_equal (dst_host , src_host )
284272
0 commit comments