@@ -558,4 +558,188 @@ fn test_large_thread_creation() {
558558 assert_eq ! ( th_a, dbg_th1, "Thread {:?} debug format changed for th1" , rth) ;
559559 assert_eq ! ( th_b, dbg_th2, "Thread {:?} debug format changed for th2" , rth) ;
560560 }
561+
562+ // Repeat yielded continuation test now with a new aux thread
563+ // Yielding continuation test (only supported on luau)
564+ #[ cfg( feature = "luau" ) ]
565+ {
566+ mlua:: Lua :: set_fflag ( "LuauYieldableContinuations" , true ) . unwrap ( ) ;
567+ }
568+
569+ let cont_func = lua
570+ . create_function_with_continuation (
571+ |_lua, a : u64 | Ok ( a + 1 ) ,
572+ |_lua, _status, a : u64 | {
573+ println ! ( "Reached cont" ) ;
574+ Ok ( a + 2 )
575+ } ,
576+ )
577+ . expect ( "Failed to create cont_func" ) ;
578+
579+ // Ensure normal calls work still
580+ assert_eq ! (
581+ lua. load( "local cont_func = ...\n return cont_func(1)" )
582+ . call:: <u64 >( cont_func)
583+ . expect( "Failed to call cont_func" ) ,
584+ 2
585+ ) ;
586+
587+ // basic yield test before we go any further
588+ let always_yield = lua
589+ . create_function ( |lua, ( ) | lua. yield_with ( ( 42 , "69420" . to_string ( ) , 45.6 ) ) )
590+ . unwrap ( ) ;
591+
592+ let thread = lua. create_thread ( always_yield) . unwrap ( ) ;
593+ assert_eq ! (
594+ thread. resume:: <( i32 , String , f32 ) >( ( ) ) . unwrap( ) ,
595+ ( 42 , String :: from( "69420" ) , 45.6 )
596+ ) ;
597+
598+ // Trigger the continuation
599+ let cont_func = lua
600+ . create_function_with_continuation (
601+ |lua, a : u64 | lua. yield_with ( a) ,
602+ |_lua, _status, a : u64 | {
603+ println ! ( "Reached cont" ) ;
604+ Ok ( a + 39 )
605+ } ,
606+ )
607+ . expect ( "Failed to create cont_func" ) ;
608+
609+ let luau_func = lua
610+ . load (
611+ "
612+ local cont_func = ...
613+ local res = cont_func(1)
614+ return res + 1
615+ " ,
616+ )
617+ . into_function ( )
618+ . expect ( "Failed to create function" ) ;
619+
620+ let th = lua
621+ . create_thread ( luau_func)
622+ . expect ( "Failed to create luau thread" ) ;
623+
624+ let v = th
625+ . resume :: < mlua:: MultiValue > ( cont_func)
626+ . expect ( "Failed to resume" ) ;
627+ let v = th. resume :: < i32 > ( v) . expect ( "Failed to load continuation" ) ;
628+
629+ assert_eq ! ( v, 41 ) ;
630+
631+ let always_yield = lua
632+ . create_function_with_continuation (
633+ |lua, ( ) | lua. yield_with ( ( 42 , "69420" . to_string ( ) , 45.6 ) ) ,
634+ |_lua, _, mv : mlua:: MultiValue | {
635+ println ! ( "Reached second continuation" ) ;
636+ if mv. is_empty ( ) {
637+ return Ok ( mv) ;
638+ }
639+ Err ( mlua:: Error :: external ( format ! ( "a{}" , mv. len( ) ) ) )
640+ } ,
641+ )
642+ . unwrap ( ) ;
643+
644+ let thread = lua. create_thread ( always_yield) . unwrap ( ) ;
645+ let mv = thread. resume :: < mlua:: MultiValue > ( ( ) ) . unwrap ( ) ;
646+ assert ! ( thread
647+ . resume:: <String >( mv)
648+ . unwrap_err( )
649+ . to_string( )
650+ . starts_with( "a3" ) ) ;
651+
652+ let cont_func = lua
653+ . create_function_with_continuation (
654+ |lua, a : u64 | lua. yield_with ( ( a + 1 , 1 ) ) ,
655+ |lua, status, args : mlua:: MultiValue | {
656+ println ! ( "Reached cont recursive/multiple: {:?}" , args) ;
657+
658+ if args. len ( ) == 5 {
659+ if cfg ! ( any( feature = "luau" , feature = "lua52" ) ) {
660+ assert_eq ! ( status, mlua:: ContinuationStatus :: Ok ) ;
661+ } else {
662+ assert_eq ! ( status, mlua:: ContinuationStatus :: Yielded ) ;
663+ }
664+ return Ok ( 6_i32 ) ;
665+ }
666+
667+ lua. yield_with ( ( args. len ( ) + 1 , args) ) ?; // thread state becomes LEN, LEN-1... 1
668+ Ok ( 1_i32 ) // this will be ignored
669+ } ,
670+ )
671+ . expect ( "Failed to create cont_func" ) ;
672+
673+ let luau_func = lua
674+ . load (
675+ "
676+ local cont_func = ...
677+ local res = cont_func(1)
678+ return res + 1
679+ " ,
680+ )
681+ . into_function ( )
682+ . expect ( "Failed to create function" ) ;
683+ let th = lua
684+ . create_thread ( luau_func)
685+ . expect ( "Failed to create luau thread" ) ;
686+
687+ let v = th
688+ . resume :: < mlua:: MultiValue > ( cont_func)
689+ . expect ( "Failed to resume" ) ;
690+ println ! ( "v={:?}" , v) ;
691+
692+ let v = th
693+ . resume :: < mlua:: MultiValue > ( v)
694+ . expect ( "Failed to load continuation" ) ;
695+ println ! ( "v={:?}" , v) ;
696+ let v = th
697+ . resume :: < mlua:: MultiValue > ( v)
698+ . expect ( "Failed to load continuation" ) ;
699+ println ! ( "v={:?}" , v) ;
700+ let v = th
701+ . resume :: < mlua:: MultiValue > ( v)
702+ . expect ( "Failed to load continuation" ) ;
703+
704+ // (2, 1) followed by ()
705+ assert_eq ! ( v. len( ) , 2 + 3 ) ;
706+
707+ let v = th. resume :: < i32 > ( v) . expect ( "Failed to load continuation" ) ;
708+
709+ assert_eq ! ( v, 7 ) ;
710+
711+ // test panics
712+ let cont_func = lua
713+ . create_function_with_continuation (
714+ |lua, a : u64 | lua. yield_with ( a) ,
715+ |_lua, _status, _a : u64 | {
716+ panic ! ( "Reached continuation which should panic!" ) ;
717+ #[ allow( unreachable_code) ]
718+ Ok ( ( ) )
719+ } ,
720+ )
721+ . expect ( "Failed to create cont_func" ) ;
722+
723+ let luau_func = lua
724+ . load (
725+ "
726+ local cont_func = ...
727+ local ok, res = pcall(cont_func, 1)
728+ assert(not ok)
729+ return tostring(res)
730+ " ,
731+ )
732+ . into_function ( )
733+ . expect ( "Failed to create function" ) ;
734+
735+ let th = lua
736+ . create_thread ( luau_func)
737+ . expect ( "Failed to create luau thread" ) ;
738+
739+ let v = th
740+ . resume :: < mlua:: MultiValue > ( cont_func)
741+ . expect ( "Failed to resume" ) ;
742+
743+ let v = th. resume :: < String > ( v) . expect ( "Failed to load continuation" ) ;
744+ assert ! ( v. contains( "Reached continuation which should panic!" ) ) ;
561745}
0 commit comments