Skip to content

Commit 4f786d6

Browse files
fix(workflow): skip un-resumed items during nested batch/loop interrupt-resume (#2644)
1 parent 8de249d commit 4f786d6

File tree

6 files changed

+591
-9
lines changed

6 files changed

+591
-9
lines changed

backend/domain/workflow/internal/compose/test/batch_test.go

Lines changed: 328 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,29 @@ package test
1919
import (
2020
"context"
2121
"fmt"
22+
"sync"
2223
"testing"
2324

25+
"github.com/bytedance/mockey"
2426
"github.com/cloudwego/eino/compose"
2527
"github.com/stretchr/testify/assert"
2628

29+
model "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
30+
"github.com/coze-dev/coze-studio/backend/domain/workflow"
2731
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
2832
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
2933
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
3034
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
3135
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
3236
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
37+
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
3338
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
3439
)
3540

41+
// interruptibleConfig is an empty node config whose RequireCheckpoint method
// reports true. TestBatch_Nested_Interrupt assigns it to the lambda node that
// raises interrupts.
type interruptibleConfig struct{}
42+
43+
// RequireCheckpoint unconditionally reports true.
func (c *interruptibleConfig) RequireCheckpoint() bool {
	return true
}
44+
3645
func TestBatch(t *testing.T) {
3746
ctx := context.Background()
3847

@@ -58,7 +67,7 @@ func TestBatch(t *testing.T) {
5867
lambdaNode1 := &schema.NodeSchema{
5968
Key: "lambda",
6069
Type: entity.NodeTypeLambda,
61-
Lambda: compose.InvokableLambda(lambda1),
70+
Lambda: compose.InvokableLambda(lambda1, compose.WithLambdaType(string(entity.NodeTypeLambda))),
6271
InputSources: []*vo.FieldInfo{
6372
{
6473
Path: compose.FieldPath{"index"},
@@ -92,7 +101,7 @@ func TestBatch(t *testing.T) {
92101
lambdaNode2 := &schema.NodeSchema{
93102
Key: "index",
94103
Type: entity.NodeTypeLambda,
95-
Lambda: compose.InvokableLambda(lambda2),
104+
Lambda: compose.InvokableLambda(lambda2, compose.WithLambdaType(string(entity.NodeTypeLambda))),
96105
InputSources: []*vo.FieldInfo{
97106
{
98107
Path: compose.FieldPath{"index"},
@@ -109,7 +118,7 @@ func TestBatch(t *testing.T) {
109118
lambdaNode3 := &schema.NodeSchema{
110119
Key: "consumer",
111120
Type: entity.NodeTypeLambda,
112-
Lambda: compose.InvokableLambda(lambda3),
121+
Lambda: compose.InvokableLambda(lambda3, compose.WithLambdaType(string(entity.NodeTypeLambda))),
113122
InputSources: []*vo.FieldInfo{
114123
{
115124
Path: compose.FieldPath{"consumer_1"},
@@ -251,7 +260,7 @@ func TestBatch(t *testing.T) {
251260
parentLambdaNode := &schema.NodeSchema{
252261
Key: "parent_predecessor_1",
253262
Type: entity.NodeTypeLambda,
254-
Lambda: compose.InvokableLambda(parentLambda),
263+
Lambda: compose.InvokableLambda(parentLambda, compose.WithLambdaType(string(entity.NodeTypeLambda))),
255264
}
256265

257266
ws := &schema.WorkflowSchema{
@@ -349,3 +358,318 @@ func TestBatch(t *testing.T) {
349358
})
350359
assert.ErrorContains(t, err, "is too large")
351360
}
361+
362+
type mockRepo struct {
363+
workflow.Repository
364+
mu sync.Mutex
365+
events []*entity.InterruptEvent
366+
cp map[string][]byte
367+
}
368+
369+
func (m *mockRepo) GenID(ctx context.Context) (int64, error) {
370+
return 10001, nil
371+
}
372+
373+
func (m *mockRepo) ListInterruptEvents(ctx context.Context, wfExeID int64) ([]*entity.InterruptEvent, error) {
374+
m.mu.Lock()
375+
defer m.mu.Unlock()
376+
return m.events, nil
377+
}
378+
379+
func (m *mockRepo) SaveInterruptEvents(ctx context.Context, wfExeID int64, events []*entity.InterruptEvent) error {
380+
m.mu.Lock()
381+
defer m.mu.Unlock()
382+
m.events = append(m.events, events...)
383+
return nil
384+
}
385+
386+
func (m *mockRepo) CreateWorkflowExecution(ctx context.Context, exe *entity.WorkflowExecution) error {
387+
return nil
388+
}
389+
390+
func (m *mockRepo) UpdateWorkflowExecution(ctx context.Context, exe *entity.WorkflowExecution, statuses []entity.WorkflowExecuteStatus) (int64, entity.WorkflowExecuteStatus, error) {
391+
return 1, 0, nil
392+
}
393+
394+
func (m *mockRepo) TryLockWorkflowExecution(ctx context.Context, executeID int64, eventID int64) (bool, entity.WorkflowExecuteStatus, error) {
395+
return true, 0, nil
396+
}
397+
398+
func (m *mockRepo) CreateNodeExecution(ctx context.Context, exe *entity.NodeExecution) error {
399+
return nil
400+
}
401+
402+
func (m *mockRepo) UpdateNodeExecution(ctx context.Context, exe *entity.NodeExecution) error {
403+
return nil
404+
}
405+
406+
func (m *mockRepo) GetFirstInterruptEvent(ctx context.Context, wfExeID int64) (*entity.InterruptEvent, bool, error) {
407+
m.mu.Lock()
408+
defer m.mu.Unlock()
409+
if len(m.events) > 0 {
410+
return m.events[0], true, nil
411+
}
412+
return nil, false, nil
413+
}
414+
415+
func (m *mockRepo) PopFirstInterruptEvent(ctx context.Context, wfExeID int64) (*entity.InterruptEvent, bool, error) {
416+
m.mu.Lock()
417+
defer m.mu.Unlock()
418+
if len(m.events) > 0 {
419+
e := m.events[0]
420+
m.events = m.events[1:]
421+
return e, true, nil
422+
}
423+
return nil, false, nil
424+
}
425+
426+
func (m *mockRepo) UpdateFirstInterruptEvent(ctx context.Context, wfExeID int64, event *entity.InterruptEvent) error {
427+
m.mu.Lock()
428+
defer m.mu.Unlock()
429+
if len(m.events) > 0 {
430+
m.events[0] = event
431+
}
432+
return nil
433+
}
434+
435+
func (m *mockRepo) GetWorkflowCancelFlag(ctx context.Context, wfExeID int64) (bool, error) {
436+
return false, nil
437+
}
438+
439+
func (m *mockRepo) Get(ctx context.Context, executeID string) ([]byte, bool, error) {
440+
m.mu.Lock()
441+
defer m.mu.Unlock()
442+
if m.cp == nil {
443+
return nil, false, nil
444+
}
445+
cp, ok := m.cp[executeID]
446+
return cp, ok, nil
447+
}
448+
449+
func (m *mockRepo) Set(ctx context.Context, executeID string, cp []byte) error {
450+
m.mu.Lock()
451+
defer m.mu.Unlock()
452+
if m.cp == nil {
453+
m.cp = make(map[string][]byte)
454+
}
455+
m.cp[executeID] = cp
456+
return nil
457+
}
458+
459+
func TestBatch_Nested_Interrupt(t *testing.T) {
460+
ctx := context.Background()
461+
462+
var callCount int
463+
var mu sync.Mutex
464+
465+
// The innermost node that will interrupt inside the SubWorkflow
466+
lambdaNode := &schema.NodeSchema{
467+
Key: "lambda",
468+
Type: entity.NodeTypeLambda,
469+
Configs: &interruptibleConfig{},
470+
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (map[string]any, error) {
471+
mu.Lock()
472+
callCount++
473+
currentCount := callCount
474+
mu.Unlock()
475+
476+
if in["resume_data"] == "a" || in["resume_data"] == "b" {
477+
interruptEvent := &entity.InterruptEvent{
478+
ID: int64(currentCount),
479+
NodeKey: "lambda",
480+
EventType: entity.InterruptEventInput,
481+
InterruptData: "{}",
482+
}
483+
return nil, compose.NewInterruptAndRerunErr(interruptEvent)
484+
}
485+
return map[string]any{"output": "ok"}, nil
486+
}, compose.WithLambdaType(string(entity.NodeTypeLambda))),
487+
InputSources: []*vo.FieldInfo{
488+
{
489+
Path: compose.FieldPath{"resume_data"},
490+
Source: vo.FieldSource{
491+
Ref: &vo.Reference{
492+
FromNodeKey: entity.EntryNodeKey,
493+
FromPath: compose.FieldPath{"resume_data"},
494+
},
495+
},
496+
},
497+
},
498+
}
499+
500+
subWfSchema := &schema.WorkflowSchema{
501+
Nodes: []*schema.NodeSchema{
502+
{Key: entity.EntryNodeKey, Type: entity.NodeTypeEntry, Configs: &entry.Config{}},
503+
lambdaNode,
504+
{Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, Configs: &exit.Config{TerminatePlan: vo.ReturnVariables},
505+
InputSources: []*vo.FieldInfo{
506+
{
507+
Path: compose.FieldPath{"output"},
508+
Source: vo.FieldSource{
509+
Ref: &vo.Reference{
510+
FromNodeKey: "lambda",
511+
FromPath: compose.FieldPath{"output"},
512+
},
513+
},
514+
},
515+
},
516+
},
517+
},
518+
Connections: []*schema.Connection{
519+
{FromNode: entity.EntryNodeKey, ToNode: "lambda"},
520+
{FromNode: "lambda", ToNode: entity.ExitNodeKey},
521+
},
522+
}
523+
subWfSchema.Init()
524+
525+
subWfNode := &schema.NodeSchema{
526+
Key: "sub_workflow",
527+
Type: entity.NodeTypeSubWorkflow,
528+
Configs: &subworkflow.Config{WorkflowID: 100},
529+
SubWorkflowSchema: subWfSchema,
530+
SubWorkflowBasic: &entity.WorkflowBasic{ID: 100, Version: "1"},
531+
InputSources: []*vo.FieldInfo{
532+
{
533+
Path: compose.FieldPath{"resume_data"},
534+
Source: vo.FieldSource{
535+
Ref: &vo.Reference{
536+
FromNodeKey: "outer_batch",
537+
FromPath: compose.FieldPath{"outer_array"},
538+
},
539+
},
540+
},
541+
},
542+
OutputSources: []*vo.FieldInfo{
543+
{
544+
Path: compose.FieldPath{"inner_output"},
545+
Source: vo.FieldSource{
546+
Ref: &vo.Reference{
547+
FromNodeKey: entity.ExitNodeKey, // SubWorkflow's exit node
548+
FromPath: compose.FieldPath{"output"},
549+
},
550+
},
551+
},
552+
},
553+
}
554+
555+
outerBatch := &schema.NodeSchema{
556+
Key: "outer_batch",
557+
Type: entity.NodeTypeBatch,
558+
Configs: &batch.Config{},
559+
InputSources: []*vo.FieldInfo{
560+
{
561+
Path: compose.FieldPath{"outer_array"},
562+
Source: vo.FieldSource{
563+
Ref: &vo.Reference{
564+
FromNodeKey: entity.EntryNodeKey,
565+
FromPath: compose.FieldPath{"outer_array"},
566+
},
567+
},
568+
},
569+
{
570+
Path: compose.FieldPath{batch.ConcurrentSizeKey},
571+
Source: vo.FieldSource{Val: int64(2)},
572+
},
573+
{
574+
Path: compose.FieldPath{batch.MaxBatchSizeKey},
575+
Source: vo.FieldSource{Val: int64(2)},
576+
},
577+
},
578+
InputTypes: map[string]*vo.TypeInfo{
579+
"outer_array": {
580+
Type: vo.DataTypeArray,
581+
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString},
582+
},
583+
},
584+
OutputSources: []*vo.FieldInfo{
585+
{
586+
Path: compose.FieldPath{"final_output"},
587+
Source: vo.FieldSource{
588+
Ref: &vo.Reference{
589+
FromNodeKey: "sub_workflow",
590+
FromPath: compose.FieldPath{"output"},
591+
},
592+
},
593+
},
594+
},
595+
}
596+
597+
ws := &schema.WorkflowSchema{
598+
Nodes: []*schema.NodeSchema{
599+
{Key: entity.EntryNodeKey, Type: entity.NodeTypeEntry, Configs: &entry.Config{}},
600+
outerBatch,
601+
subWfNode,
602+
{Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, Configs: &exit.Config{TerminatePlan: vo.ReturnVariables}},
603+
},
604+
Hierarchy: map[vo.NodeKey]vo.NodeKey{
605+
"sub_workflow": "outer_batch",
606+
},
607+
Connections: []*schema.Connection{
608+
{FromNode: entity.EntryNodeKey, ToNode: "outer_batch"},
609+
{FromNode: "outer_batch", ToNode: "sub_workflow"},
610+
{FromNode: "sub_workflow", ToNode: "outer_batch"},
611+
{FromNode: "outer_batch", ToNode: entity.ExitNodeKey},
612+
},
613+
}
614+
615+
ws.Init()
616+
basic := &entity.WorkflowBasic{ID: 1, Version: "1"}
617+
618+
// MOCK Repository globally
619+
myRepo := &mockRepo{}
620+
mockPatch := mockey.Mock(workflow.GetRepository).To(func() workflow.Repository {
621+
return myRepo
622+
}).Build()
623+
defer mockPatch.UnPatch()
624+
625+
// 1. Initial run: 2 items in outer array, both should interrupt
626+
initialRunner := compose2.NewWorkflowRunner(basic, ws, model.ExecuteConfig{})
627+
initialCtx, executeID, opts, _, err := initialRunner.Prepare(ctx)
628+
assert.NoError(t, err)
629+
630+
wf, err := compose2.NewWorkflow(initialCtx, ws, compose2.WithIDAsName(basic.ID))
631+
assert.NoError(t, err)
632+
633+
_, err = wf.Runner.Invoke(initialCtx, map[string]any{
634+
"outer_array": []any{"a", "b"},
635+
"resume_data": nil,
636+
}, opts...)
637+
638+
t.Logf("Initial run error: %v", err)
639+
assert.Error(t, err)
640+
if callCount != 2 {
641+
t.Fatalf("Expected callCount 2, got %d. Error was: %v", callCount, err)
642+
}
643+
644+
// 2. Resume the first event returned by GetFirstInterruptEvent
645+
repo := workflow.GetRepository()
646+
event0, _, _ := repo.GetFirstInterruptEvent(ctx, executeID)
647+
assert.NotNil(t, event0, "Event 0 should not be nil")
648+
if event0 == nil {
649+
t.Fatal("Event0 is nil, cannot proceed")
650+
}
651+
652+
// Create a new runner for resumption
653+
resumeRunner := compose2.NewWorkflowRunner(basic, ws, model.ExecuteConfig{},
654+
compose2.WithResumeReq(&entity.ResumeRequest{
655+
ExecuteID: executeID,
656+
EventID: event0.ID,
657+
ResumeData: "resumed",
658+
}))
659+
660+
resumeCtx, _, resumeOpts, _, err := resumeRunner.Prepare(ctx)
661+
assert.NoError(t, err)
662+
663+
// Invoke resumption
664+
_, err = wf.Runner.Invoke(resumeCtx, map[string]any{
665+
"outer_array": []any{nil, nil},
666+
"resume_data": "resumed",
667+
}, resumeOpts...)
668+
669+
assert.Error(t, err) // Should still be interrupted at index 1
670+
671+
// CRITICAL ASSERTION:
672+
// If the fix works, callCount should be 3 (2 from initial + 1 from resumed index 0).
673+
// If index 1 was NOT skipped (the bug), callCount would be 4.
674+
assert.Equal(t, 3, callCount, "Index 1 of the outer batch should have been skipped")
675+
}

0 commit comments

Comments
 (0)