20 changes: 3 additions & 17 deletions aesara/link/jax/dispatch/basic.py
@@ -30,6 +30,8 @@ def jax_typify(data, dtype=None, **kwargs):

@jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype=None, **kwargs):
if len(data.shape) == 0:
return data.item()
return jnp.array(data, dtype=dtype)


@@ -82,26 +84,10 @@ def assert_fn(x, *inputs):
return assert_fn


def jnp_safe_copy(x):
try:
res = jnp.copy(x)
except NotImplementedError:
warnings.warn(
"`jnp.copy` is not implemented yet. " "Using the object's `copy` method."
)
if hasattr(x, "copy"):
res = jnp.array(x.copy())
else:
warnings.warn(f"Object has no `copy` method: {x}")
res = x

return res


@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return jnp_safe_copy(x)
return jnp.copy(x)

return deepcopyop

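The two changes to basic.py above are related: `jax_typify` now returns a plain Python scalar for 0-d ndarrays so that concrete values survive later passes, and `DeepCopyOp` calls `jnp.copy` directly instead of the removed `jnp_safe_copy` fallback. The following sketch only illustrates those behaviours; it is not part of the diff and assumes a recent JAX release in which `jnp.copy` is implemented.

import numpy as np
import jax.numpy as jnp

data = np.array(3.0)           # 0-d ndarray
assert len(data.shape) == 0
print(data.item())             # -> 3.0, a concrete Python float rather than a JAX array

x = jnp.arange(4)
print(jnp.copy(x))             # -> [0 1 2 3]; works directly, no fallback needed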
4 changes: 2 additions & 2 deletions aesara/link/jax/dispatch/elemwise.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp

from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad

@@ -69,7 +69,7 @@ def dimshuffle(x):
res = jnp.reshape(res, shape)

if not op.inplace:
res = jnp_safe_copy(res)
res = jnp.copy(res)

return res

103 changes: 101 additions & 2 deletions aesara/link/jax/dispatch/scalar.py
@@ -5,14 +5,57 @@

from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.basic import (
Add,
Cast,
Clip,
Composite,
Identity,
IntDiv,
Mod,
Mul,
ScalarOp,
Second,
Sub,
)
from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi


def check_if_inputs_scalars(node):
"""Check whether all the inputs of an `Elemwise` are scalar values.

`jax.lax` or `jax.numpy` functions systematically return traced arrays,
while the corresponding Python operators return concrete values when passed
concrete values. In order to compile as many graphs as possible we need to
preserve concrete values whenever we can, and thus dispatch Aesara operators
differently depending on whether their inputs are scalars.

"""
ndims_input = [inp.type.ndim for inp in node.inputs]
are_inputs_scalars = True
for ndim in ndims_input:
try:
if ndim > 0:
are_inputs_scalars = False
except TypeError:
are_inputs_scalars = False

return are_inputs_scalars


@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, **kwargs):
def jax_funcify_ScalarOp(op, node, **kwargs):
func_name = op.nfunc_spec[0]

# We dispatch some Aesara operators to Python operators
# whenever the inputs are all scalars.
are_inputs_scalars = check_if_inputs_scalars(node)
if are_inputs_scalars:
elemwise = elemwise_scalar(op)
if elemwise is not None:
return elemwise

if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
@@ -38,6 +81,54 @@ def elemwise(*args):
return jnp_func


@functools.singledispatch
def elemwise_scalar(op):
return None


@elemwise_scalar.register(Add)
def elemwise_scalar_add(op):
def elemwise(*inputs):
return sum(inputs)

return elemwise


@elemwise_scalar.register(Mul)
def elemwise_scalar_mul(op):
import operator
from functools import reduce

def elemwise(*inputs):
return reduce(operator.mul, inputs, 1)

return elemwise


@elemwise_scalar.register(Sub)
def elemwise_scalar_sub(op):
def elemwise(x, y):
return x - y

return elemwise


@elemwise_scalar.register(IntDiv)
def elemwise_scalar_intdiv(op):
def elemwise(x, y):
return x // y

return elemwise


@elemwise_scalar.register(Mod)
def elemwise_scalar_mod(op):
def elemwise(x, y):
return x % y

return elemwise


@jax_funcify.register(Cast)
def jax_funcify_Cast(op, **kwargs):
def cast(x):
@@ -56,6 +147,14 @@ def identity(x):

@jax_funcify.register(Clip)
def jax_funcify_Clip(op, **kwargs):
"""Register the translation for the `Clip` `Op`.

Aesara's `Clip` operator behaves differently from NumPy's when the
specified `min` is larger than the `max`, so we cannot reuse `jax.numpy.clip`
and stay consistent with Aesara.

"""

def clip(x, min, max):
return jnp.where(x < min, min, jnp.where(x > max, max, x))

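To illustrate why `jax_funcify_ScalarOp` falls back to the Python operators registered with `elemwise_scalar` when every input is a scalar: Python operators applied to concrete Python numbers return concrete Python numbers, whereas the `jax.numpy` equivalents return JAX arrays (and tracers under `jax.jit`), so the concrete value can no longer be used where a Python scalar is required. The snippet below is only a sketch of that difference and is not part of the diff.

import jax.numpy as jnp

a, b = 2, 3
print(type(a + b))           # <class 'int'>: stays a concrete Python value
print(type(jnp.add(a, b)))   # a JAX array type; under `jax.jit` this would be a tracer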
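The nested `jnp.where` used for `Clip` reproduces Aesara's convention when the lower bound exceeds the upper bound, which differs from NumPy's. A quick check of the two conventions, assuming NumPy's behaviour of letting the upper bound win (illustrative only, not part of the diff):

import numpy as np
import jax.numpy as jnp

x, lo, hi = 1.5, 2.0, 1.0                                # deliberately lo > hi
print(np.clip(x, lo, hi))                                # -> 1.0, the upper bound wins
print(jnp.where(x < lo, lo, jnp.where(x > hi, hi, x)))   # -> 2.0, the lower bound wins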
34 changes: 31 additions & 3 deletions aesara/link/jax/dispatch/shape.py
@@ -28,18 +28,46 @@ def shape_tuple_fn(*x):
return shape_tuple_fn


SHAPE_NOT_COMPATIBLE = """JAX requires concrete values for the `shape` parameter of `jax.numpy.reshape`.
Concrete values are either constants:

>>> import aesara.tensor as at
>>> x = at.ones(6)
>>> y = x.reshape((2, 3))

Or the shape of an array:

>>> mat = at.matrix('mat')
>>> y = x.reshape(mat.shape)
"""


def assert_shape_argument_jax_compatible(shape):
"""Assert whether the current node can be JIT-compiled by JAX.

JAX can JIT-compile functions with a `shape` or `size` argument if it is
given a concrete value, i.e. either a constant or the shape of any traced
value.

"""
shape_op = shape.owner.op
if not isinstance(shape_op, (Shape, Shape_i, JAXShapeTuple)):
raise NotImplementedError(SHAPE_NOT_COMPATIBLE)


@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):

# JAX reshape only works with constant inputs, otherwise JIT fails
shape = node.inputs[1]

if isinstance(shape, Constant):
constant_shape = shape.data

def reshape(x, shape):
return jnp.reshape(x, constant_shape)

else:
assert_shape_argument_jax_compatible(shape)

def reshape(x, shape):
return jnp.reshape(x, shape)
@@ -66,10 +94,10 @@ def shape_i(x):


@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs):
def jax_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
assert x.shape == tuple(shape), (
"got shape",
x.shape,
"expected",
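As the `Reshape` dispatch and the error message above indicate, `jnp.reshape` only accepts a concrete `shape` under `jax.jit`; the shape of a traced array is a tuple of static Python integers and therefore remains acceptable. A minimal sketch of that constraint, not part of the diff:

import jax
import jax.numpy as jnp

@jax.jit
def ok(x):
    # `x.shape` is a tuple of Python ints even inside `jit`, so this compiles.
    return jnp.reshape(x, (x.shape[0] * x.shape[1],))

@jax.jit
def not_ok(x, n):
    # `n` is traced inside `jit`, so JAX raises a concretization error here.
    return jnp.reshape(x, (n,))

print(ok(jnp.ones((2, 3))))        # works
# not_ok(jnp.ones((2, 3)), 6)      # raises a concretization error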
79 changes: 50 additions & 29 deletions aesara/link/jax/dispatch/subtensor.py
@@ -1,5 +1,3 @@
import jax

from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
@@ -13,46 +11,75 @@
from aesara.tensor.type_other import MakeSlice


BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
masks. In some cases, however, it is possible to re-express your model
in a form that JAX can compile:

>>> import aesara.tensor as at
>>> x_at = at.vector('x')
>>> y_at = x_at[x_at > 0].sum()

can be re-expressed as:

>>> import aesara.tensor as at
>>> x_at = at.vector('x')
>>> y_at = at.where(x_at > 0, x_at, 0).sum()
"""

DYNAMIC_SLICE_LENGTH_ERROR = """JAX does not support slicing arrays with a dynamic
slice length.
"""


def subtensor_assert_indices_jax_compatible(node, idx_list):
from aesara.graph.basic import Constant
from aesara.tensor.var import TensorVariable

ilist = indices_from_subtensor(node.inputs[1:], idx_list)
for idx in ilist:

if isinstance(idx, TensorVariable):
if idx.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR)
elif isinstance(idx, slice):
for slice_arg in (idx.start, idx.stop, idx.step):
if slice_arg is not None and not isinstance(slice_arg, Constant):
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, **kwargs):
def jax_funcify_Subtensor(op, node, **kwargs):

idx_list = getattr(op, "idx_list", None)
subtensor_assert_indices_jax_compatible(node, idx_list)

def subtensor(x, *ilists):

def subtensor_constant(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)

if len(indices) == 1:
indices = indices[0]

return x.__getitem__(indices)

return subtensor
return subtensor_constant


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, **kwargs):
def jax_funcify_IncSubtensor(op, node, **kwargs):

idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)

if jax_fn is None:

def jax_fn(x, indices, y):
return x.at[indices].set(y)
def jax_fn(x, indices, y):
return x.at[indices].set(y)

else:
jax_fn = getattr(jax.ops, "index_add", None)

if jax_fn is None:

def jax_fn(x, indices, y):
return x.at[indices].add(y)
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
@@ -65,23 +92,17 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):


@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):

if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)

if jax_fn is None:

def jax_fn(x, indices, y):
return x.at[indices].set(y)
def jax_fn(x, indices, y):
return x.at[indices].set(y)

else:
jax_fn = getattr(jax.ops, "index_add", None)

if jax_fn is None:

def jax_fn(x, indices, y):
return x.at[indices].add(y)
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
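The `getattr(jax.ops, "index_update", None)` and `getattr(jax.ops, "index_add", None)` fallbacks are dropped above in favour of the `.at[...]` property, which is the functional update API in current JAX releases. A small usage sketch (illustrative only, not part of the diff):

import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 2])

print(x.at[idx].set(1.0))   # -> [1. 0. 1. 0. 0.], the `set_instead_of_inc=True` case
print(x.at[idx].add(3.0))   # -> [3. 0. 3. 0. 0.], the increment case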