-
-
Notifications
You must be signed in to change notification settings - Fork 152
Fix the JAX Subtensor and IncSubtensor dispatcher
#1338
Conversation
b53c174 to
7fb0948
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #1338 +/- ##
==========================================
+ Coverage 74.36% 74.66% +0.29%
==========================================
Files 177 177
Lines 49066 49050 -16
Branches 10379 10400 +21
==========================================
+ Hits 36488 36623 +135
+ Misses 10285 10131 -154
- Partials 2293 2296 +3
|
721e087 to
8814219
Compare
|
Here is another issue I would like to fix in this PR. When operating on scalars, import jax
import jax.numpy as jnp
@jax.jit
def fn():
a = 3 + 2
b = jnp.ones(3) + jnp.ones(3)
print(a)
print(b)
return a
@jax.jit
def fn_lax_add():
a = jax.lax.add(3, 2)
print(a)
return a
fn()
# 5
# Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>
fn_lax_add()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>import jax
@jax.jit
def fn():
a = 3 * 2
print(a)
return a
@jax.jit
def fn_lax_mul():
a = jax.lax.mul(3, 2)
print(a)
return a
fn()
# 6
fn_lax_mul()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>import jax
@jax.jit
def fn():
a = 3 / 2
print(a)
return a
@jax.jit
def fn_lax_div():
a = jax.lax.div(3, 2)
print(a)
return a
fn()
# 1.5
fn_lax_div()
# Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>Since this can cause problems down the line when the result of the operation is passed as a |
|
The problem is illustrated by the following MWE: import aesara
import aesara.tensor as at
import numpy as np
x = at.matrix('x')
shape = x.shape[0] + x.shape[1]
out = at.ones(shape)
fn = aesara.function((x,), out, mode="JAX")
try:
fn(np.ones((2,3)))
except Exception as e:
print(e)
# Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
# If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
# Apply node that caused the error: Alloc(TensorConstant{1.0}, Elemwise{Add}[(0, 0)].0)
# Toposort index: 3
# Inputs types: [TensorType(float64, ()), TensorType(int64, ())]
# Inputs shapes: [(2, 3)]
# Inputs strides: [(24, 8)]
# Inputs values: ['not shown']
# Outputs clients: [['output']]While the equivalent JAX implementation: import jax
import jax.numpy as jnp
import numpy as np
@jax.jit
def fn(x):
shape = x.shape[0] + x.shape[1]
return jnp.ones(shape)
print(fn(np.ones((2,3))))
# [1. 1. 1. 1. 1.] |
d4d7ce9 to
8267fff
Compare
8267fff to
4dea813
Compare
49b4c3d to
b0de7b9
Compare
|
I have now fixed everything all the issues I identified in #1202 that were not related to the There is one last error that I do not understand in self = <aesara.tensor.shape.Reshape object at 0x7fbb71e8e9e0>
node = Reshape{2}(a, JAXShapeTuple.0), inp = [array([1., 2., 3., 4.]), None]
out_ = [[None]], params = Params(ndim:int32:2)
def perform(self, node, inp, out_, params):
x, shp = inp
(out,) = out_
> if len(shp) != self.ndim:
E TypeError: object of type 'NoneType' has no len()Do you see what might be the issue @brandonwillard ? |
Yeah, there shouldn't be a |
b0de7b9 to
2a3baa5
Compare
06382af to
14bb49c
Compare
14bb49c to
d937ab5
Compare
bd697dc to
78d2dbf
Compare
|
Ready for review. |
In this PR I fix the outstanding issues with the
SubtensorandIncSubtensordispatchers in the JAX backend, and a few other things that came up in #1202.jax.lax.dynamic_sliceProgress
RandomVariabledispatcher #1284)start,stop,stepinARange, raise if dynamicSubtensorif slicing with dynamic length;Elemwiseoperations on scalar values to Python operatorsIncSubtensorjax.numpy.copydirectlyTensorFromScalaras a pass-throughThis one requires more thoughts on scalars in JAX vs scalars in Aesara.
Reshapeimplementation to make sure theshapeparameter is passed concrete values (waiting for Fix the JAXRandomVariabledispatcher #1284)Allow combination of concrete values using Python operators asThey already areshapeandsizearguments (we may need a rewrite and custom Ops likeJAXPythonAddto make this easier).This is a spinoff of #1202.