Skip to content

Commit e30641e

Browse files
committed
Convert eagerly to XTensor to avoid issue with DataArrays
Binary operations with DataArrays on the left try to access XTensorVariable .coords raising NotImplementedError.
1 parent 421b915 commit e30641e

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
247247
if (alpha is not None) and (beta is not None):
248248
pass
249249
elif (mu is not None) and (sigma is not None):
250+
mu = as_xtensor(mu)
251+
sigma = as_xtensor(sigma)
250252
# Use sign of sigma to not let negative sigma fly by
251253
alpha = (mu**2 / sigma**2) * ptx.math.sign(sigma)
252254
beta = mu / sigma**2
@@ -269,6 +271,8 @@ def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
269271
beta = 1.0
270272
elif (mu is not None) and (sigma is not None):
271273
# Use sign of sigma to not let negative sigma fly by
274+
mu = as_xtensor(mu)
275+
sigma = as_xtensor(sigma)
272276
alpha = ((2 * sigma**2 + mu**2) / sigma**2) * ptx.math.sign(sigma)
273277
beta = mu * (mu**2 + sigma**2) / sigma**2
274278
else:

0 commit comments

Comments
 (0)