Unexhaustive list of structural improvements #39
Description
desilike development started when JAX wasn't so widespread, but desilike now relies heavily on JAX. I think we should bite the bullet and structure desilike in a more JAX-friendly way:
- all data structures, including Calculator-derived classes, should be pytrees, following Equinox ideas. Then everything will pass cleanly through standard jax.jit and jax.vmap, and through every JAX-based sampler.
- support for multi-dimensional parameters, and allow for Samples to be split row-wise (in case the parameter space is really large...).
- make the Sampler and Profiler API a bit nicer, so that we can swap in a different sampler and keep the export to Samples the same, e.g. a function make_sampler(run=likelihood -> array of samples for varied parameters) would return a Sampler.
- for non-JAX functions, one could use callbacks.
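
To make the pytree and callback points above concrete, here is a minimal sketch. It is not desilike code: `LinearCalculator`, `chi2` and `legacy_numpy_model` are hypothetical names invented for illustration, and it uses plain `jax.tree_util.register_pytree_node_class` rather than Equinox to stay dependency-free. The idea is that array-valued fields become pytree leaves (traced by jit/vmap), while non-array metadata travels as static auxiliary data.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass(frozen=True)
class LinearCalculator:
    """Hypothetical calculator (not a desilike class): slope * x + offset."""
    slope: jax.Array       # pytree leaf, traced by jit/vmap
    offset: jax.Array      # pytree leaf, traced by jit/vmap
    name: str = "linear"   # static metadata, carried as auxiliary data

    def tree_flatten(self):
        # (children, aux_data): children are leaves, aux_data must be hashable.
        return (self.slope, self.offset), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        slope, offset = children
        return cls(slope, offset, aux_data)

    def __call__(self, x):
        return self.slope * x + self.offset


@jax.jit
def chi2(calc, x, data):
    # The calculator is passed directly as a jit argument because it is a pytree.
    return jnp.sum((calc(x) - data) ** 2)


x = jnp.linspace(0.0, 1.0, 5)
data = 2.0 * x + 1.0

calc = LinearCalculator(jnp.asarray(2.0), jnp.asarray(1.0))
chi2_single = chi2(calc, x, data)

# A "batch" of calculators is the same class with a leading axis on each leaf.
batch = LinearCalculator(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 1.0]))
chi2_batch = jax.vmap(chi2, in_axes=(0, None, None))(batch, x, data)


def legacy_numpy_model(x):
    # Stand-in for a non-JAX code path (hypothetical); runs on the host.
    return np.sin(np.asarray(x))


def wrapped_model(x):
    # pure_callback needs the output shape and dtype declared up front.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(legacy_numpy_model, out_type, x)


y = jax.jit(wrapped_model)(x)
```

Note that a function wrapped with `jax.pure_callback` is not differentiable by default, so gradient-based samplers would still need a custom derivative rule (or a JAX port) for such calculators.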
I'm not keen on relying on numpyro at the moment (our Parameter class works well), but we should certainly implement a to_numpyro() method.
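
As a rough illustration of the `make_sampler` idea from the list above, the sketch below wraps an arbitrary `run` function into an object with a fixed export format. Everything here is an assumption made for illustration: the `Samples` and `Sampler` classes are simplified stand-ins, not desilike's actual classes, and `metropolis_run` is just one possible `run` function (a basic random-walk Metropolis loop written with `jax.lax.scan`).

```python
from functools import partial
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class Samples(NamedTuple):
    """Hypothetical stand-in for a samples container."""
    names: Tuple[str, ...]
    values: jax.Array  # shape (nsamples, nparams)


class Sampler:
    """Hypothetical thin wrapper: any `run` function, same export."""

    def __init__(self, run: Callable, names: Tuple[str, ...]):
        self._run = run
        self.names = tuple(names)

    def sample(self, loglike: Callable) -> Samples:
        values = self._run(loglike)
        return Samples(self.names, jnp.atleast_2d(values))


def make_sampler(run: Callable, names: Tuple[str, ...]) -> Sampler:
    # `run` maps a log-likelihood function to an array of samples.
    return Sampler(run, names)


def metropolis_run(loglike, *, key, start, nsteps=500, step=0.5):
    """Random-walk Metropolis; one of many possible `run` functions."""

    def one_step(state, step_key):
        x, logp = state
        key_prop, key_acc = jax.random.split(step_key)
        proposal = x + step * jax.random.normal(key_prop, x.shape)
        logp_prop = loglike(proposal)
        accept = jnp.log(jax.random.uniform(key_acc)) < logp_prop - logp
        x = jnp.where(accept, proposal, x)
        logp = jnp.where(accept, logp_prop, logp)
        return (x, logp), x

    keys = jax.random.split(key, nsteps)
    _, chain = jax.lax.scan(one_step, (start, loglike(start)), keys)
    return chain


def gaussian_loglike(theta):
    # Toy likelihood: unit Gaussian centred on (1, -1).
    return -0.5 * jnp.sum((theta - jnp.array([1.0, -1.0])) ** 2)


run = partial(metropolis_run, key=jax.random.PRNGKey(0), start=jnp.zeros(2))
sampler = make_sampler(run, names=("a", "b"))
samples = sampler.sample(gaussian_loglike)
```

Swapping in a different sampler would then only mean passing a different `run` function to `make_sampler`, while the code that consumes `Samples` stays unchanged.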