fix(gh-2036): MyPy Errors in Distributions Module#2050
fix(gh-2036): MyPy Errors in Distributions Module#2050Qazalbash wants to merge 3 commits intopyro-ppl:masterfrom
Conversation
fehiepsi
left a comment
There was a problem hiding this comment.
Thanks @Qazalbash! I thought we can merge your changes and iterate on later PRs. Turns out that there are many non-trivial changes so it would be nice to split this into several PRs, each for one script like continuous.py etc. (would prefer simple ones first).
Let's resolve the typing issues when you have bandwidth. Thanks again for this important contribution!
|
|
||
| Message: TypeAlias = dict[str, Any] | ||
| TraceT: TypeAlias = OrderedDict[str, Message] | ||
| PRNGKeyT: TypeAlias = Union[jax.dtypes.prng_key, ArrayLike] |
There was a problem hiding this comment.
I think we can just use jax.Array for PRNGKey
| from jax import Array | ||
| from jax.typing import ArrayLike | ||
|
|
||
| from numpyro.distributions import MaskedDistribution |
There was a problem hiding this comment.
would prefer not importing MaskedDistribution here
| "numpyro.contrib.hsgp.*", | ||
| "numpyro.contrib.stochastic_support.*", | ||
| "numpyro.diagnostics.*", | ||
| "numpyro.distributions.*", |
| self, | ||
| fn: Optional[Callable] = None, | ||
| scale: ArrayLike = 1.0, | ||
| scale: Array = 1.0, |
There was a problem hiding this comment.
note: ArrayLike works for float
| self, | ||
| fn: Optional[Callable] = None, | ||
| mask: Optional[ArrayLike] = True, | ||
| mask: Optional[Array] = True, |
There was a problem hiding this comment.
note: ArrayLike works for boolean
|
|
||
| @singledispatch | ||
| def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs): | ||
| def vmap_over(d: DistributionT, **kwargs): |
Thanks, I will make sure these issues get resolved as early as possible. |
|
@Qazalbash shall we close this one as we are tackling it bit by bit on smaller PRs? :) |
|
Yes, please, it serves no purpose. Also, let me know which module we should tackle next. |
This PR contains the partial resolution of mypy errors passed by #2032.