1- import contextlib
21import math
32import warnings
43from types import ModuleType
2423from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
2524from array_api_extra ._lib ._utils ._compat import device as get_device
2625from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
27- from array_api_extra ._lib ._utils ._typing import Array , Device
26+ from array_api_extra ._lib ._utils ._typing import Device
2827from array_api_extra .testing import lazy_xp_function
2928
3029# some xp backends are untyped
@@ -287,23 +286,12 @@ def test_xp(self, xp: ModuleType):
287286
288287
289288class TestExpandDims :
290- @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "tuple index out of range" )
291- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "tuple index out of range" )
292- @pytest .mark .xfail_xp_backend (Backend .TORCH , reason = "tuple index out of range" )
293- def test_functionality (self , xp : ModuleType ):
294- def _squeeze_all (b : Array ) -> Array :
295- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
296- for axis in range (b .ndim ):
297- with contextlib .suppress (ValueError ):
298- b = xp .squeeze (b , axis = axis )
299- return b
300-
301- s = (2 , 3 , 4 , 5 )
302- a = xp .empty (s )
289+ def test_single_axis (self , xp : ModuleType ):
290+ """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
291+ a = xp .empty ((2 , 3 , 4 , 5 ))
303292 for axis in range (- 5 , 4 ):
304293 b = expand_dims (a , axis = axis )
305- assert b .shape [axis ] == 1
306- assert _squeeze_all (b ).shape == s
294+ xp_assert_equal (b , xp .expand_dims (a , axis = axis ))
307295
308296 def test_axis_tuple (self , xp : ModuleType ):
309297 a = xp .empty ((3 , 3 , 3 ))
@@ -313,8 +301,7 @@ def test_axis_tuple(self, xp: ModuleType):
313301 assert expand_dims (a , axis = (0 , - 3 , - 5 )).shape == (1 , 1 , 3 , 1 , 3 , 3 )
314302
315303 def test_axis_out_of_range (self , xp : ModuleType ):
316- s = (2 , 3 , 4 , 5 )
317- a = xp .empty (s )
304+ a = xp .empty ((2 , 3 , 4 , 5 ))
318305 with pytest .raises (IndexError , match = "out of bounds" ):
319306 _ = expand_dims (a , axis = - 6 )
320307 with pytest .raises (IndexError , match = "out of bounds" ):
0 commit comments