Skip to content

Commit d9701e0

Browse files
MNT: drop device= in cov weights
1 parent 0c57e2b commit d9701e0

2 files changed

Lines changed: 4 additions & 8 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,10 @@ def cov(
230230

231231
if m.ndim <= 2 and integer_correction:
232232
if is_torch_namespace(xp):
233-
device = get_device(m)
234233
fw = (
235-
None
236-
if frequency_weights is None
237-
else xp.asarray(frequency_weights, device=device)
234+
None if frequency_weights is None else xp.asarray(frequency_weights)
238235
)
239-
aw = None if weights is None else xp.asarray(weights, device=device)
236+
aw = None if weights is None else xp.asarray(weights)
240237
return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw)
241238
# `dask.array.cov` forces `.compute()` whenever weights are given:
242239
# its internal `if fact <= 0` check on a lazy 0-D scalar triggers

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,16 +299,15 @@ def cov(
299299
m = atleast_nd(m, ndim=2, xp=xp)
300300
m = xp.astype(m, dtype)
301301

302-
device = _compat.device(m)
303302
fw = (
304303
None
305304
if frequency_weights is None
306-
else xp.astype(xp.asarray(frequency_weights, device=device), dtype)
305+
else xp.astype(xp.asarray(frequency_weights), dtype)
307306
)
308307
aw = (
309308
None
310309
if weights is None
311-
else xp.astype(xp.asarray(weights, device=device), dtype)
310+
else xp.astype(xp.asarray(weights), dtype)
312311
)
313312
if fw is None and aw is None:
314313
w = None

0 commit comments

Comments
 (0)