1313def jax_funcify_Scan (op , node , ** kwargs ):
1414 scan_inner_fn = jax_funcify (op .fgraph )
1515 input_taps = {
16+ "mit_mot" : op .info .mit_mot_in_slices ,
1617 "mit_sot" : op .info .mit_sot_in_slices ,
1718 "sit_sot" : op .info .sit_sot_in_slices ,
1819 "nit_sot" : op .info .sit_sot_in_slices ,
@@ -32,9 +33,6 @@ def parse_outer_inputs(outer_inputs):
3233 "shared" : list (op .outer_shared (outer_inputs )),
3334 "non_sequences" : list (op .outer_non_seqs (outer_inputs )),
3435 }
35- if len (outer_in ["mit_mot" ]) > 0 :
36- raise NotImplementedError ("mit-mot not supported" )
37-
3836 return outer_in
3937
4038 if op .info .as_while :
@@ -323,7 +321,7 @@ def build_jax_scan_inputs(outer_in: Dict):
323321 sequences = outer_in ["sequences" ]
324322 init_carry = {
325323 name : outer_in [name ]
326- for name in ["mit_sot" , "sit_sot" , "shared" , "non_sequences" ]
324+ for name in ["mit_mot" , " mit_sot" , "sit_sot" , "shared" , "non_sequences" ]
327325 }
328326 init_carry ["step" ] = 0
329327 return n_steps , sequences , init_carry
@@ -340,7 +338,7 @@ def build_inner_outputs_map(outer_in):
340338 [+ while-condition]
341339
342340 """
343- inner_outputs_names = ["mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
341+ inner_outputs_names = ["mit_mot" , " mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
344342
345343 offset = 0
346344 inner_output_idx = defaultdict (list )
@@ -415,6 +413,9 @@ def scan_inner_in_args(carry, x):
415413 current_step = carry ["step" ]
416414
417415 inner_in_seqs = x
416+ inner_in_mit_mot = from_carry_storage (
417+ carry ["mit_mot" ], current_step , input_taps ["mit_mot" ]
418+ )
418419 inner_in_mit_sot = from_carry_storage (
419420 carry ["mit_sot" ], current_step , input_taps ["mit_sot" ]
420421 )
@@ -427,6 +428,7 @@ def scan_inner_in_args(carry, x):
427428 return sum (
428429 [
429430 inner_in_seqs ,
431+ inner_in_mit_mot ,
430432 inner_in_mit_sot ,
431433 inner_in_sit_sot ,
432434 inner_in_shared ,
@@ -439,6 +441,7 @@ def scan_new_carry(carry, inner_outputs):
439441 """Create a new carry value from the values returned by the inner function (inner-outputs)."""
440442 step = carry ["step" ]
441443 new_carry = {
444+ "mit_mot" : [],
442445 "mit_sot" : [],
443446 "sit_sot" : [],
444447 "shared" : [],
@@ -452,6 +455,14 @@ def scan_new_carry(carry, inner_outputs):
452455 ]
453456 new_carry ["shared" ] = shared_inner_outputs
454457
458+ if "mit_mot" in inner_output_idx :
459+ mit_mot_inner_outputs = [
460+ inner_outputs [idx ] for idx in inner_output_idx ["mit_mot" ]
461+ ]
462+ new_carry ["mit_mot" ] = to_carry_storage (
463+ mit_mot_inner_outputs , carry ["mit_mot" ], step , input_taps ["mit_mot" ]
464+ )
465+
455466 if "mit_sot" in inner_output_idx :
456467 mit_sot_inner_outputs = [
457468 inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]
@@ -486,6 +497,10 @@ def scan_new_outputs(inner_outputs):
486497
487498 """
488499 outer_outputs = []
500+ if "mit_mot" in inner_output_idx :
501+ outer_outputs .append (
502+ [inner_outputs [idx ] for idx in inner_output_idx ["mit_mot" ]]
503+ )
489504 if "mit_sot" in inner_output_idx :
490505 outer_outputs .append (
491506 [inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]]
0 commit comments