Skip to content

Commit eda11a0

Browse files
committed
Merge branch 'main' into sparse016
2 parents 09fff82 + ca55dc9 commit eda11a0

1 file changed

Lines changed: 6 additions & 19 deletions

File tree

tests/test_funcs.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import math
32
import warnings
43
from types import ModuleType
@@ -24,7 +23,7 @@
2423
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2524
from array_api_extra._lib._utils._compat import device as get_device
2625
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
27-
from array_api_extra._lib._utils._typing import Array, Device
26+
from array_api_extra._lib._utils._typing import Device
2827
from array_api_extra.testing import lazy_xp_function
2928

3029
# some xp backends are untyped
@@ -287,23 +286,12 @@ def test_xp(self, xp: ModuleType):
287286

288287

289288
class TestExpandDims:
290-
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
291-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="tuple index out of range")
292-
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
293-
def test_functionality(self, xp: ModuleType):
294-
def _squeeze_all(b: Array) -> Array:
295-
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
296-
for axis in range(b.ndim):
297-
with contextlib.suppress(ValueError):
298-
b = xp.squeeze(b, axis=axis)
299-
return b
300-
301-
s = (2, 3, 4, 5)
302-
a = xp.empty(s)
289+
def test_single_axis(self, xp: ModuleType):
290+
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
291+
a = xp.empty((2, 3, 4, 5))
303292
for axis in range(-5, 4):
304293
b = expand_dims(a, axis=axis)
305-
assert b.shape[axis] == 1
306-
assert _squeeze_all(b).shape == s
294+
xp_assert_equal(b, xp.expand_dims(a, axis=axis))
307295

308296
def test_axis_tuple(self, xp: ModuleType):
309297
a = xp.empty((3, 3, 3))
@@ -313,8 +301,7 @@ def test_axis_tuple(self, xp: ModuleType):
313301
assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)
314302

315303
def test_axis_out_of_range(self, xp: ModuleType):
316-
s = (2, 3, 4, 5)
317-
a = xp.empty(s)
304+
a = xp.empty((2, 3, 4, 5))
318305
with pytest.raises(IndexError, match="out of bounds"):
319306
_ = expand_dims(a, axis=-6)
320307
with pytest.raises(IndexError, match="out of bounds"):

0 commit comments

Comments
 (0)