@@ -19,20 +19,29 @@ package test
1919import (
2020 "context"
2121 "fmt"
22+ "sync"
2223 "testing"
2324
25+ "github.com/bytedance/mockey"
2426 "github.com/cloudwego/eino/compose"
2527 "github.com/stretchr/testify/assert"
2628
29+ model "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
30+ "github.com/coze-dev/coze-studio/backend/domain/workflow"
2731 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
2832 "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
2933 compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
3034 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
3135 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
3236 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
37+ "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
3338 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
3439)
3540
41+ type interruptibleConfig struct {}
42+
43+ func (c * interruptibleConfig ) RequireCheckpoint () bool { return true }
44+
3645func TestBatch (t * testing.T ) {
3746 ctx := context .Background ()
3847
@@ -58,7 +67,7 @@ func TestBatch(t *testing.T) {
5867 lambdaNode1 := & schema.NodeSchema {
5968 Key : "lambda" ,
6069 Type : entity .NodeTypeLambda ,
61- Lambda : compose .InvokableLambda (lambda1 ),
70+ Lambda : compose .InvokableLambda (lambda1 , compose . WithLambdaType ( string ( entity . NodeTypeLambda )) ),
6271 InputSources : []* vo.FieldInfo {
6372 {
6473 Path : compose.FieldPath {"index" },
@@ -92,7 +101,7 @@ func TestBatch(t *testing.T) {
92101 lambdaNode2 := & schema.NodeSchema {
93102 Key : "index" ,
94103 Type : entity .NodeTypeLambda ,
95- Lambda : compose .InvokableLambda (lambda2 ),
104+ Lambda : compose .InvokableLambda (lambda2 , compose . WithLambdaType ( string ( entity . NodeTypeLambda )) ),
96105 InputSources : []* vo.FieldInfo {
97106 {
98107 Path : compose.FieldPath {"index" },
@@ -109,7 +118,7 @@ func TestBatch(t *testing.T) {
109118 lambdaNode3 := & schema.NodeSchema {
110119 Key : "consumer" ,
111120 Type : entity .NodeTypeLambda ,
112- Lambda : compose .InvokableLambda (lambda3 ),
121+ Lambda : compose .InvokableLambda (lambda3 , compose . WithLambdaType ( string ( entity . NodeTypeLambda )) ),
113122 InputSources : []* vo.FieldInfo {
114123 {
115124 Path : compose.FieldPath {"consumer_1" },
@@ -251,7 +260,7 @@ func TestBatch(t *testing.T) {
251260 parentLambdaNode := & schema.NodeSchema {
252261 Key : "parent_predecessor_1" ,
253262 Type : entity .NodeTypeLambda ,
254- Lambda : compose .InvokableLambda (parentLambda ),
263+ Lambda : compose .InvokableLambda (parentLambda , compose . WithLambdaType ( string ( entity . NodeTypeLambda )) ),
255264 }
256265
257266 ws := & schema.WorkflowSchema {
@@ -349,3 +358,318 @@ func TestBatch(t *testing.T) {
349358 })
350359 assert .ErrorContains (t , err , "is too large" )
351360}
361+
362+ type mockRepo struct {
363+ workflow.Repository
364+ mu sync.Mutex
365+ events []* entity.InterruptEvent
366+ cp map [string ][]byte
367+ }
368+
369+ func (m * mockRepo ) GenID (ctx context.Context ) (int64 , error ) {
370+ return 10001 , nil
371+ }
372+
373+ func (m * mockRepo ) ListInterruptEvents (ctx context.Context , wfExeID int64 ) ([]* entity.InterruptEvent , error ) {
374+ m .mu .Lock ()
375+ defer m .mu .Unlock ()
376+ return m .events , nil
377+ }
378+
379+ func (m * mockRepo ) SaveInterruptEvents (ctx context.Context , wfExeID int64 , events []* entity.InterruptEvent ) error {
380+ m .mu .Lock ()
381+ defer m .mu .Unlock ()
382+ m .events = append (m .events , events ... )
383+ return nil
384+ }
385+
386+ func (m * mockRepo ) CreateWorkflowExecution (ctx context.Context , exe * entity.WorkflowExecution ) error {
387+ return nil
388+ }
389+
390+ func (m * mockRepo ) UpdateWorkflowExecution (ctx context.Context , exe * entity.WorkflowExecution , statuses []entity.WorkflowExecuteStatus ) (int64 , entity.WorkflowExecuteStatus , error ) {
391+ return 1 , 0 , nil
392+ }
393+
394+ func (m * mockRepo ) TryLockWorkflowExecution (ctx context.Context , executeID int64 , eventID int64 ) (bool , entity.WorkflowExecuteStatus , error ) {
395+ return true , 0 , nil
396+ }
397+
398+ func (m * mockRepo ) CreateNodeExecution (ctx context.Context , exe * entity.NodeExecution ) error {
399+ return nil
400+ }
401+
402+ func (m * mockRepo ) UpdateNodeExecution (ctx context.Context , exe * entity.NodeExecution ) error {
403+ return nil
404+ }
405+
406+ func (m * mockRepo ) GetFirstInterruptEvent (ctx context.Context , wfExeID int64 ) (* entity.InterruptEvent , bool , error ) {
407+ m .mu .Lock ()
408+ defer m .mu .Unlock ()
409+ if len (m .events ) > 0 {
410+ return m .events [0 ], true , nil
411+ }
412+ return nil , false , nil
413+ }
414+
415+ func (m * mockRepo ) PopFirstInterruptEvent (ctx context.Context , wfExeID int64 ) (* entity.InterruptEvent , bool , error ) {
416+ m .mu .Lock ()
417+ defer m .mu .Unlock ()
418+ if len (m .events ) > 0 {
419+ e := m .events [0 ]
420+ m .events = m .events [1 :]
421+ return e , true , nil
422+ }
423+ return nil , false , nil
424+ }
425+
426+ func (m * mockRepo ) UpdateFirstInterruptEvent (ctx context.Context , wfExeID int64 , event * entity.InterruptEvent ) error {
427+ m .mu .Lock ()
428+ defer m .mu .Unlock ()
429+ if len (m .events ) > 0 {
430+ m .events [0 ] = event
431+ }
432+ return nil
433+ }
434+
435+ func (m * mockRepo ) GetWorkflowCancelFlag (ctx context.Context , wfExeID int64 ) (bool , error ) {
436+ return false , nil
437+ }
438+
439+ func (m * mockRepo ) Get (ctx context.Context , executeID string ) ([]byte , bool , error ) {
440+ m .mu .Lock ()
441+ defer m .mu .Unlock ()
442+ if m .cp == nil {
443+ return nil , false , nil
444+ }
445+ cp , ok := m .cp [executeID ]
446+ return cp , ok , nil
447+ }
448+
449+ func (m * mockRepo ) Set (ctx context.Context , executeID string , cp []byte ) error {
450+ m .mu .Lock ()
451+ defer m .mu .Unlock ()
452+ if m .cp == nil {
453+ m .cp = make (map [string ][]byte )
454+ }
455+ m .cp [executeID ] = cp
456+ return nil
457+ }
458+
459+ func TestBatch_Nested_Interrupt (t * testing.T ) {
460+ ctx := context .Background ()
461+
462+ var callCount int
463+ var mu sync.Mutex
464+
465+ // The innermost node that will interrupt inside the SubWorkflow
466+ lambdaNode := & schema.NodeSchema {
467+ Key : "lambda" ,
468+ Type : entity .NodeTypeLambda ,
469+ Configs : & interruptibleConfig {},
470+ Lambda : compose .InvokableLambda (func (ctx context.Context , in map [string ]any ) (map [string ]any , error ) {
471+ mu .Lock ()
472+ callCount ++
473+ currentCount := callCount
474+ mu .Unlock ()
475+
476+ if in ["resume_data" ] == "a" || in ["resume_data" ] == "b" {
477+ interruptEvent := & entity.InterruptEvent {
478+ ID : int64 (currentCount ),
479+ NodeKey : "lambda" ,
480+ EventType : entity .InterruptEventInput ,
481+ InterruptData : "{}" ,
482+ }
483+ return nil , compose .NewInterruptAndRerunErr (interruptEvent )
484+ }
485+ return map [string ]any {"output" : "ok" }, nil
486+ }, compose .WithLambdaType (string (entity .NodeTypeLambda ))),
487+ InputSources : []* vo.FieldInfo {
488+ {
489+ Path : compose.FieldPath {"resume_data" },
490+ Source : vo.FieldSource {
491+ Ref : & vo.Reference {
492+ FromNodeKey : entity .EntryNodeKey ,
493+ FromPath : compose.FieldPath {"resume_data" },
494+ },
495+ },
496+ },
497+ },
498+ }
499+
500+ subWfSchema := & schema.WorkflowSchema {
501+ Nodes : []* schema.NodeSchema {
502+ {Key : entity .EntryNodeKey , Type : entity .NodeTypeEntry , Configs : & entry.Config {}},
503+ lambdaNode ,
504+ {Key : entity .ExitNodeKey , Type : entity .NodeTypeExit , Configs : & exit.Config {TerminatePlan : vo .ReturnVariables },
505+ InputSources : []* vo.FieldInfo {
506+ {
507+ Path : compose.FieldPath {"output" },
508+ Source : vo.FieldSource {
509+ Ref : & vo.Reference {
510+ FromNodeKey : "lambda" ,
511+ FromPath : compose.FieldPath {"output" },
512+ },
513+ },
514+ },
515+ },
516+ },
517+ },
518+ Connections : []* schema.Connection {
519+ {FromNode : entity .EntryNodeKey , ToNode : "lambda" },
520+ {FromNode : "lambda" , ToNode : entity .ExitNodeKey },
521+ },
522+ }
523+ subWfSchema .Init ()
524+
525+ subWfNode := & schema.NodeSchema {
526+ Key : "sub_workflow" ,
527+ Type : entity .NodeTypeSubWorkflow ,
528+ Configs : & subworkflow.Config {WorkflowID : 100 },
529+ SubWorkflowSchema : subWfSchema ,
530+ SubWorkflowBasic : & entity.WorkflowBasic {ID : 100 , Version : "1" },
531+ InputSources : []* vo.FieldInfo {
532+ {
533+ Path : compose.FieldPath {"resume_data" },
534+ Source : vo.FieldSource {
535+ Ref : & vo.Reference {
536+ FromNodeKey : "outer_batch" ,
537+ FromPath : compose.FieldPath {"outer_array" },
538+ },
539+ },
540+ },
541+ },
542+ OutputSources : []* vo.FieldInfo {
543+ {
544+ Path : compose.FieldPath {"inner_output" },
545+ Source : vo.FieldSource {
546+ Ref : & vo.Reference {
547+ FromNodeKey : entity .ExitNodeKey , // SubWorkflow's exit node
548+ FromPath : compose.FieldPath {"output" },
549+ },
550+ },
551+ },
552+ },
553+ }
554+
555+ outerBatch := & schema.NodeSchema {
556+ Key : "outer_batch" ,
557+ Type : entity .NodeTypeBatch ,
558+ Configs : & batch.Config {},
559+ InputSources : []* vo.FieldInfo {
560+ {
561+ Path : compose.FieldPath {"outer_array" },
562+ Source : vo.FieldSource {
563+ Ref : & vo.Reference {
564+ FromNodeKey : entity .EntryNodeKey ,
565+ FromPath : compose.FieldPath {"outer_array" },
566+ },
567+ },
568+ },
569+ {
570+ Path : compose.FieldPath {batch .ConcurrentSizeKey },
571+ Source : vo.FieldSource {Val : int64 (2 )},
572+ },
573+ {
574+ Path : compose.FieldPath {batch .MaxBatchSizeKey },
575+ Source : vo.FieldSource {Val : int64 (2 )},
576+ },
577+ },
578+ InputTypes : map [string ]* vo.TypeInfo {
579+ "outer_array" : {
580+ Type : vo .DataTypeArray ,
581+ ElemTypeInfo : & vo.TypeInfo {Type : vo .DataTypeString },
582+ },
583+ },
584+ OutputSources : []* vo.FieldInfo {
585+ {
586+ Path : compose.FieldPath {"final_output" },
587+ Source : vo.FieldSource {
588+ Ref : & vo.Reference {
589+ FromNodeKey : "sub_workflow" ,
590+ FromPath : compose.FieldPath {"output" },
591+ },
592+ },
593+ },
594+ },
595+ }
596+
597+ ws := & schema.WorkflowSchema {
598+ Nodes : []* schema.NodeSchema {
599+ {Key : entity .EntryNodeKey , Type : entity .NodeTypeEntry , Configs : & entry.Config {}},
600+ outerBatch ,
601+ subWfNode ,
602+ {Key : entity .ExitNodeKey , Type : entity .NodeTypeExit , Configs : & exit.Config {TerminatePlan : vo .ReturnVariables }},
603+ },
604+ Hierarchy : map [vo.NodeKey ]vo.NodeKey {
605+ "sub_workflow" : "outer_batch" ,
606+ },
607+ Connections : []* schema.Connection {
608+ {FromNode : entity .EntryNodeKey , ToNode : "outer_batch" },
609+ {FromNode : "outer_batch" , ToNode : "sub_workflow" },
610+ {FromNode : "sub_workflow" , ToNode : "outer_batch" },
611+ {FromNode : "outer_batch" , ToNode : entity .ExitNodeKey },
612+ },
613+ }
614+
615+ ws .Init ()
616+ basic := & entity.WorkflowBasic {ID : 1 , Version : "1" }
617+
618+ // MOCK Repository globally
619+ myRepo := & mockRepo {}
620+ mockPatch := mockey .Mock (workflow .GetRepository ).To (func () workflow.Repository {
621+ return myRepo
622+ }).Build ()
623+ defer mockPatch .UnPatch ()
624+
625+ // 1. Initial run: 2 items in outer array, both should interrupt
626+ initialRunner := compose2 .NewWorkflowRunner (basic , ws , model.ExecuteConfig {})
627+ initialCtx , executeID , opts , _ , err := initialRunner .Prepare (ctx )
628+ assert .NoError (t , err )
629+
630+ wf , err := compose2 .NewWorkflow (initialCtx , ws , compose2 .WithIDAsName (basic .ID ))
631+ assert .NoError (t , err )
632+
633+ _ , err = wf .Runner .Invoke (initialCtx , map [string ]any {
634+ "outer_array" : []any {"a" , "b" },
635+ "resume_data" : nil ,
636+ }, opts ... )
637+
638+ t .Logf ("Initial run error: %v" , err )
639+ assert .Error (t , err )
640+ if callCount != 2 {
641+ t .Fatalf ("Expected callCount 2, got %d. Error was: %v" , callCount , err )
642+ }
643+
644+ // 2. Resume the first event returned by GetFirstInterruptEvent
645+ repo := workflow .GetRepository ()
646+ event0 , _ , _ := repo .GetFirstInterruptEvent (ctx , executeID )
647+ assert .NotNil (t , event0 , "Event 0 should not be nil" )
648+ if event0 == nil {
649+ t .Fatal ("Event0 is nil, cannot proceed" )
650+ }
651+
652+ // Create a new runner for resumption
653+ resumeRunner := compose2 .NewWorkflowRunner (basic , ws , model.ExecuteConfig {},
654+ compose2 .WithResumeReq (& entity.ResumeRequest {
655+ ExecuteID : executeID ,
656+ EventID : event0 .ID ,
657+ ResumeData : "resumed" ,
658+ }))
659+
660+ resumeCtx , _ , resumeOpts , _ , err := resumeRunner .Prepare (ctx )
661+ assert .NoError (t , err )
662+
663+ // Invoke resumption
664+ _ , err = wf .Runner .Invoke (resumeCtx , map [string ]any {
665+ "outer_array" : []any {nil , nil },
666+ "resume_data" : "resumed" ,
667+ }, resumeOpts ... )
668+
669+ assert .Error (t , err ) // Should still be interrupted at index 1
670+
671+ // CRITICAL ASSERTION:
672+ // If the fix works, callCount should be 3 (2 from initial + 1 from resumed index 0).
673+ // If index 1 was NOT skipped (the bug), callCount would be 4.
674+ assert .Equal (t , 3 , callCount , "Index 1 of the outer batch should have been skipped" )
675+ }
0 commit comments