Skip to content

Commit 21b0931

Browse files
committed
ENH: add linalg.eig, linalg.eigvals
1 parent c303adc commit 21b0931

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

array_api_strict/_linalg.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ._data_type_functions import finfo
1010
from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128
1111
from ._elementwise_functions import conj
12-
from ._flags import get_array_api_strict_flags, requires_extension
12+
from ._flags import get_array_api_strict_flags, requires_extension, requires_api_version
1313
from ._manipulation_functions import reshape
1414
from ._statistical_functions import _np_dtype_sumprod
1515

@@ -23,6 +23,10 @@ class EighResult(NamedTuple):
2323
eigenvalues: Array
2424
eigenvectors: Array
2525

26+
class EigResult(NamedTuple):
27+
eigenvalues: Array
28+
eigenvectors: Array
29+
2630
class QRResult(NamedTuple):
2731
Q: Array
2832
R: Array
@@ -144,6 +148,63 @@ def eigvalsh(x: Array, /) -> Array:
144148

145149
return Array._new(np.linalg.eigvalsh(x._array), device=x.device)
146150

151+
@requires_extension('linalg')
152+
@requires_api_version('2025.12')
153+
def eigvals(x: Array, /) -> Array:
154+
"""
155+
Array API compatible wrapper for :py:func:`np.linalg.eigvals <numpy.linalg.eigvals>`.
156+
157+
See its docstring for more information.
158+
"""
159+
# Note: the restriction to floating-point dtypes only is different from
160+
# np.linalg.eigvals.
161+
if x.dtype not in _floating_dtypes:
162+
raise TypeError('Only floating-point dtypes are allowed in eigvals')
163+
164+
res = np.linalg.eigvals(x._array)
165+
166+
# numpy return reals for real inputs
167+
res_dtype = res.dtype
168+
if res.dtype == np.float32:
169+
res_dtype = np.complex64
170+
elif res.dtype == np.float64:
171+
res_dtype = np.complex128
172+
173+
if res_dtype != res.dtype:
174+
res = res.astype(res_dtype)
175+
176+
return Array._new(res, device=x.device)
177+
178+
179+
@requires_extension('linalg')
180+
@requires_api_version('2025.12')
181+
def eig(x: Array, /) -> EigResult:
182+
"""
183+
Array API compatible wrapper for :py:func:`np.linalg.eig <numpy.linalg.eig>`.
184+
185+
See its docstring for more information.
186+
"""
187+
# Note: the restriction to floating-point dtypes only is different from
188+
# np.linalg.eig.
189+
if x.dtype not in _floating_dtypes:
190+
raise TypeError('Only floating-point dtypes are allowed in eig')
191+
192+
w, vr = np.linalg.eig(x._array)
193+
194+
# numpy return reals for real inputs
195+
res_dtype = w.dtype
196+
if w.dtype == np.float32:
197+
res_dtype = np.complex64
198+
elif w.dtype == np.float64:
199+
res_dtype = np.complex128
200+
201+
if res_dtype != w.dtype:
202+
w = w.astype(res_dtype)
203+
vr = vr.astype(res_dtype)
204+
205+
return EigResult(Array._new(w, device=x.device), Array._new(vr, device=x.device))
206+
207+
147208
@requires_extension('linalg')
148209
def inv(x: Array, /) -> Array:
149210
"""

0 commit comments

Comments
 (0)