ENH: add diag_indices, tril_indices, triu_indices#692
ENH: add diag_indices, tril_indices, triu_indices#692bruAristimunha wants to merge 2 commits intodata-apis:mainfrom
diag_indices, tril_indices, triu_indices#692Conversation
Resolves data-apis#686. Adds the three index-generating functions that numpy, jax, and cupy all have but that are missing from the array-api standard and (so far) from this library. Signatures follow array-api conventions: parameter `offset` (matching `xp.linalg.diagonal`) instead of numpy's `k`; keyword-only arguments for everything except `n`; `xp` is required (these functions have no input array to infer from, following the `default_dtype` precedent). Delegation: - numpy/cupy/jax: forward directly (signatures match verbatim). - dask: has tril/triu_indices but no diag_indices. - torch: has tril/triu_indices but with (row, col, *, offset) signature returning a 2xN tensor rather than a tuple; delegation translates. No torch.diag_indices exists; falls through to generic. - sparse, array-api-strictest: fall through to generic; marked xfail on those backends (no nonzero / data-dependent shapes). Generic implementation uses `xp.arange` + broadcasting + `xp.nonzero` for the triangle variants. Validation (n >= 0, ndim >= 1, m >= 0) happens in the delegation layer so all backends produce consistent ValueErrors. Also fixes a pre-existing bug in tests/conftest.py's NumPyReadOnly wrapper: `type(o)(*gen)` worked for namedtuples but failed for plain tuples of length >= 2. Exposed here because these are the first functions in the library that return a tuple of arrays.
e905c26 to
2654568
Compare
|
hey @lucascolley! I was wondering, if you have some time, could you please review this PR? I would be happy to address and review any comments that you have. I am open to address in any macro way to make the revision simpler (put more comments, get the reference and etc) |
|
hey @bruAristimunha! I'm pretty busy right now but I'll try to take a look at this soon, thanks |
diag_indices, tril_indices, triu_indices
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
| def tril_indices( | ||
| n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType | ||
| ) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 | ||
| """See docstring in array_api_extra._delegation.""" | ||
| return _tri_indices(n, offset=offset, m=m, upper=False, xp=xp) | ||
|
|
||
|
|
||
| def triu_indices( | ||
| n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType | ||
| ) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 | ||
| """See docstring in array_api_extra._delegation.""" | ||
| return _tri_indices(n, offset=offset, m=m, upper=True, xp=xp) |
There was a problem hiding this comment.
somewhat of a side-comment: what is the general feeling on boolean upper parameters nowadays @ev-br ? I wonder whether we should consider exposing a single function with an upper parameter rather than two separate functions.
lucascolley
left a comment
There was a problem hiding this comment.
thanks @bruAristimunha, looks great! I haven't taken a detailed look yet, but a few comments for now.
- please could you update https://github.com/data-apis/array-api-extra/blob/main/docs/api-reference.md ?
- please could you also add
deviceandxptests, e.g.array-api-extra/tests/test_funcs.py
Lines 764 to 771 in 2690f22
- please could you add a note to the Tri function docstrings that the generic implementation requires
nonzeroto be implemented.
Resolves #686.
Summary
diag_indices,tril_indices,triu_indices— three index-generating functions that numpy, jax, and cupy all have but that are missing from the array-api standard and from this library.offset(matchingxp.linalg.diagonal) instead of numpy'sk; keyword-only arguments for everything exceptn;xpis required (these functions have no input array to infer from, following thedefault_dtypeprecedent).Numpy migration
Delegation
tril_indices/triu_indicesbut nodiag_indices— the last one falls through to the generic impl.tril_indices/triu_indicesbut with(row, col, *, offset)signature returning a 2×N tensor rather than a tuple; delegation translates. Notorch.diag_indicesexists.nonzero/ data-dependent shapes).Validation (
n >= 0,ndim >= 1,m >= 0) happens in the delegation layer via a shared_check_nonneghelper, so all backends emit consistentValueErrors before any backend-specific code runs.Generic implementation
diag_indices→(xp.arange(n),) * ndim.tril_indices/triu_indices→ shared_tri_indiceshelper:xp.arange+ broadcasting +xp.nonzeroon the mask. Pure array-api, fully lazy on dask.Also in this PR
Fixes a pre-existing bug in
tests/conftest.py'sNumPyReadOnlywrapper:type(o)(*gen)worked for namedtuples but failed for plain tuples of length ≥ 2. Exposed here because these are the first functions in the library that return a tuple of arrays.Test plan
pytest tests/test_funcs.py::TestDiagIndices tests/test_funcs.py::TestTriIndices— 155 passed across numpy, torch, jax, dask, array-api-strict (+ xfail on sparse/strictest/dask-use_to_read where noted).pytest tests/test_funcs.pyfull — all passing.lefthook run pre-commit --all-files— ruff, numpydoc, mypy, pyright, blacken-docs, validate-pyproject, dprint, typos all green.lazy_xp_function(tril_indices)/lazy_xp_function(triu_indices)assert 0.compute()calls, holds for both native and generic paths.