Add JAX-specific rewrites and use static graph analysis to raise early #1259
Description
As the refactor of the JAX `Scan` dispatcher progresses, more and more branching logic is being added throughout the dispatcher. All of it stems from JAX's restrictions on shapes and slicing. Every graph I have encountered so far describes a computation that can be implemented in JAX, but this is not always obvious from the graph Aesara produces.
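The shape and slicing restriction can be reproduced in a few lines. The sketch below uses only public JAX APIs (`jax.jit` with `static_argnums`, and `jax.lax.dynamic_slice`); `take_prefix` is a hypothetical helper for illustration. A slice whose length depends on a traced value cannot be compiled, while two common workarounds do compile: making the bound static, or keeping the slice size static and only the start dynamic.

```python
import jax
import jax.numpy as jnp


def take_prefix(x, n):
    # Under jit, `n` is a tracer, so `x[:n]` would have a data-dependent shape.
    return x[:n]


x = jnp.arange(10)

# Traced slice bound: JAX refuses to trace it. The exact exception type
# may differ across JAX versions, so this sketch catches broadly.
try:
    jax.jit(take_prefix)(x, 3)
except Exception as err:
    print("rejected:", type(err).__name__)

# Marking `n` as static turns it into a Python int at trace time;
# JAX then compiles one specialization per distinct value of `n`.
prefix = jax.jit(take_prefix, static_argnums=1)(x, 3)  # [0 1 2]

# A dynamic *start* is fine as long as the slice *size* is static.
window = jax.jit(lambda x, i: jax.lax.dynamic_slice(x, (i,), (3,)))(x, 4)  # [4 5 6]
```

Neither workaround is expressed in the original graph; choosing between them is exactly the kind of decision that could be made by a rewrite rather than by branching inside the dispatcher.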
Since we are lowering to JAX from a static computation graph, I think we should use that graph to take a defensive approach. The idea is to move the complexity that currently lives in the dispatcher upstream, and to raise as early as possible in the transpilation process when necessary:
- Analyse the graph, and raise if we find a pattern that is not admissible in JAX;
- If no such pattern is found, canonicalize the graph into a JAX-friendly form so that it can be lowered with the simplest possible dispatcher;
- In addition, there are known "tricks" to make JAX programs compile. They often involve grouping computations, and rewrites combined with custom JAX-specific types and dispatch would automate these tricks, effectively making the Aesara JAX backend easier to use than JAX itself.
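A minimal, self-contained sketch of the first step ("analyse the graph and raise") over a toy graph follows. `Node`, `toposort`, `is_static`, `check_jax_admissible` and `JAXIncompatibleGraphError` are hypothetical names for illustration, not Aesara's API; a real pass would presumably walk an Aesara `FunctionGraph` instead. It assumes a single rule: a slice's stop bound must be known at trace time. Bounds derived only from constants and shapes are accepted (input shapes are concrete when JAX traces), while bounds depending on runtime inputs are rejected before any JAX code is generated.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can go in a set
class Node:
    """Toy stand-in for a graph node; NOT Aesara's Apply/Variable API."""

    op: str  # "input", "constant", "shape", "sub", "slice", ...
    inputs: List["Node"] = field(default_factory=list)
    value: Optional[int] = None  # only meaningful for "constant" nodes


class JAXIncompatibleGraphError(Exception):
    """Hypothetical error raised before any JAX code is generated."""


def toposort(output: Node) -> List[Node]:
    """Return every node reachable from `output`, inputs before consumers."""
    order: List[Node] = []
    seen = set()

    def visit(node: Node) -> None:
        if node in seen:
            return
        seen.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    visit(output)
    return order


def is_static(node: Node) -> bool:
    """True if the value is known when JAX traces the function.

    Constants and shapes are static; arithmetic on static values stays
    static; graph inputs are runtime values.
    """
    if node.op in ("constant", "shape"):
        return True
    if node.op == "input":
        return False
    return all(is_static(inp) for inp in node.inputs)


def check_jax_admissible(output: Node) -> None:
    """Raise early if a slice bound depends on a runtime value."""
    for node in toposort(output):
        if node.op == "slice":
            _, stop = node.inputs
            if not is_static(stop):
                raise JAXIncompatibleGraphError(
                    "slice stop depends on a runtime value; the output shape "
                    "would not be known when JAX traces the function"
                )


x = Node("input")
n = Node("input")
one = Node("constant", value=1)

# x[: x.shape[0] - 1] passes: the bound only depends on a shape.
check_jax_admissible(Node("slice", [x, Node("sub", [Node("shape", [x]), one])]))

# x[:n] is rejected at graph-analysis time.
try:
    check_jax_admissible(Node("slice", [x, n]))
except JAXIncompatibleGraphError as err:
    print(err)
```

The second step of the list would then be a set of rewrites that run only once this check passes, so the dispatcher can assume a canonical, JAX-friendly graph.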
Since JAX follows XLA's semantics very closely and inherits its constraints, this line of work will also be useful when we start targeting XLA directly. More broadly, using static graph analysis, backend-specific types and graph rewrites to tailor the graph to each backend's constraints is a better approach than writing very complex dispatchers.