Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 2e7b3a8

Browse files
committed
Support mit-mots in the JAX backend
1 parent e0850cc commit 2e7b3a8

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
def jax_funcify_Scan(op, node, **kwargs):
1414
scan_inner_fn = jax_funcify(op.fgraph)
1515
input_taps = {
16+
"mit_mot": op.info.mit_mot_in_slices,
1617
"mit_sot": op.info.mit_sot_in_slices,
1718
"sit_sot": op.info.sit_sot_in_slices,
1819
"nit_sot": op.info.sit_sot_in_slices,
@@ -32,9 +33,6 @@ def parse_outer_inputs(outer_inputs):
3233
"shared": list(op.outer_shared(outer_inputs)),
3334
"non_sequences": list(op.outer_non_seqs(outer_inputs)),
3435
}
35-
if len(outer_in["mit_mot"]) > 0:
36-
raise NotImplementedError("mit-mot not supported")
37-
3836
return outer_in
3937

4038
if op.info.as_while:
@@ -323,7 +321,7 @@ def build_jax_scan_inputs(outer_in: Dict):
323321
sequences = outer_in["sequences"]
324322
init_carry = {
325323
name: outer_in[name]
326-
for name in ["mit_sot", "sit_sot", "shared", "non_sequences"]
324+
for name in ["mit_mot", "mit_sot", "sit_sot", "shared", "non_sequences"]
327325
}
328326
init_carry["step"] = 0
329327
return n_steps, sequences, init_carry
@@ -340,7 +338,7 @@ def build_inner_outputs_map(outer_in):
340338
[+ while-condition]
341339
342340
"""
343-
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
341+
inner_outputs_names = ["mit_mot", "mit_sot", "sit_sot", "nit_sot", "shared"]
344342

345343
offset = 0
346344
inner_output_idx = defaultdict(list)
@@ -415,6 +413,9 @@ def scan_inner_in_args(carry, x):
415413
current_step = carry["step"]
416414

417415
inner_in_seqs = x
416+
inner_in_mit_mot = from_carry_storage(
417+
carry["mit_mot"], current_step, input_taps["mit_mot"]
418+
)
418419
inner_in_mit_sot = from_carry_storage(
419420
carry["mit_sot"], current_step, input_taps["mit_sot"]
420421
)
@@ -427,6 +428,7 @@ def scan_inner_in_args(carry, x):
427428
return sum(
428429
[
429430
inner_in_seqs,
431+
inner_in_mit_mot,
430432
inner_in_mit_sot,
431433
inner_in_sit_sot,
432434
inner_in_shared,
@@ -439,6 +441,7 @@ def scan_new_carry(carry, inner_outputs):
439441
"""Create a new carry value from the values returned by the inner function (inner-outputs)."""
440442
step = carry["step"]
441443
new_carry = {
444+
"mit_mot": [],
442445
"mit_sot": [],
443446
"sit_sot": [],
444447
"shared": [],
@@ -452,6 +455,14 @@ def scan_new_carry(carry, inner_outputs):
452455
]
453456
new_carry["shared"] = shared_inner_outputs
454457

458+
if "mit_mot" in inner_output_idx:
459+
mit_mot_inner_outputs = [
460+
inner_outputs[idx] for idx in inner_output_idx["mit_mot"]
461+
]
462+
new_carry["mit_mot"] = to_carry_storage(
463+
mit_mot_inner_outputs, carry["mit_mot"], step, input_taps["mit_mot"]
464+
)
465+
455466
if "mit_sot" in inner_output_idx:
456467
mit_sot_inner_outputs = [
457468
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
@@ -486,6 +497,10 @@ def scan_new_outputs(inner_outputs):
486497
487498
"""
488499
outer_outputs = []
500+
if "mit_mot" in inner_output_idx:
501+
outer_outputs.append(
502+
[inner_outputs[idx] for idx in inner_output_idx["mit_mot"]]
503+
)
489504
if "mit_sot" in inner_output_idx:
490505
outer_outputs.append(
491506
[inner_outputs[idx] for idx in inner_output_idx["mit_sot"]]

tests/link/jax/test_scan.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from packaging.version import parse as version_parse
44

55
import aesara.tensor as at
6-
from aesara import function
6+
from aesara import function, grad
77
from aesara.compile.mode import Mode
88
from aesara.configdefaults import config
99
from aesara.graph.fg import FunctionGraph
@@ -23,6 +23,7 @@
2323
# Disable all optimizations
2424
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
2525
jax_mode = Mode(JAXLinker(), opts)
26+
py_mode = Mode("py", opts)
2627

2728

2829
def test_while_cannnot_use_all_outputs():
@@ -66,8 +67,8 @@ def test_sit_sot():
6667
n_steps=3,
6768
)
6869

69-
jax_fn = function((a_at,), res, updates=updates, mode="JAX")
7070
fn = function((a_at,), res, updates=updates)
71+
jax_fn = function((a_at,), res, updates=updates, mode="JAX")
7172
assert np.allclose(fn(1.0), jax_fn(1.0))
7273

7374

@@ -413,3 +414,29 @@ def power_step(prior_result, x):
413414

414415
for output_jax, output in zip(jax_res, res):
415416
assert np.allclose(jax_res, res)
417+
418+
419+
def test_mitmots_basic():
420+
421+
init_x = at.dvector()
422+
seq = at.dvector()
423+
424+
def inner_fct(seq, state_old, state_current):
425+
return state_old * 2 + state_current + seq
426+
427+
out, _ = scan(
428+
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
429+
)
430+
431+
g_outs = grad(out.sum(), [seq, init_x])
432+
433+
out_fg = FunctionGraph([seq, init_x], g_outs)
434+
435+
seq_val = np.arange(3)
436+
init_x_val = np.r_[-2, -1]
437+
(seq_val, init_x_val)
438+
439+
fn = function(out_fg.inputs, out_fg.outputs)
440+
jax_fn = function(out_fg.inputs, out_fg.outputs, mode="JAX")
441+
print(fn(seq_val, init_x_val))
442+
print(jax_fn(seq_val, init_x_val))

0 commit comments

Comments
 (0)