From b751fd1aab7ea509c386a643cbbee145f5a606ac Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 20 Mar 2026 19:52:19 +0900 Subject: [PATCH 1/5] proxy: honor blocking command timeouts --- proxy/backend.go | 7 +++ proxy/blocking.go | 54 ++++++++++++++++++++++ proxy/dualwrite.go | 8 +++- proxy/proxy_test.go | 108 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 proxy/blocking.go diff --git a/proxy/backend.go b/proxy/backend.go index 9960cee1..c61dec31 100644 --- a/proxy/backend.go +++ b/proxy/backend.go @@ -85,6 +85,13 @@ func (b *RedisBackend) Do(ctx context.Context, args ...any) *redis.Cmd { return b.client.Do(ctx, args...) } +// DoWithTimeout executes a command using a per-call socket timeout override. +// This is used for blocking commands whose wait time exceeds the backend's +// default read timeout. +func (b *RedisBackend) DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd { + return b.client.WithTimeout(timeout).Do(ctx, args...) +} + func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) { pipe := b.client.Pipeline() results := make([]*redis.Cmd, len(cmds)) diff --git a/proxy/blocking.go b/proxy/blocking.go new file mode 100644 index 00000000..18a42a4e --- /dev/null +++ b/proxy/blocking.go @@ -0,0 +1,54 @@ +package proxy + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +const blockingMultiPopMinArgs = 2 + +type blockingTimeoutBackend interface { + DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd +} + +func blockingCommandTimeout(cmd string, args [][]byte) time.Duration { + switch strings.ToUpper(cmd) { + case "BLPOP", "BRPOP", "BRPOPLPUSH", "BLMOVE", "BZPOPMIN", "BZPOPMAX": + if len(args) == 0 { + return 0 + } + return parseBlockingSecondsArg(args[len(args)-1]) + case "BLMPOP": + if len(args) < blockingMultiPopMinArgs { + return 0 + } + return parseBlockingSecondsArg(args[1]) + case "XREAD", "XREADGROUP": + for i := 1; i+1 < len(args); i++ { + if strings.EqualFold(string(args[i]), "BLOCK") { + return parseBlockingMillisecondsArg(args[i+1]) + } + } + } + return 0 +} + +func parseBlockingSecondsArg(raw []byte) time.Duration { + seconds, err := strconv.ParseFloat(string(raw), 64) + if err != nil || seconds < 0 { + return 0 + } + return time.Duration(seconds * float64(time.Second)) +} + +func parseBlockingMillisecondsArg(raw []byte) time.Duration { + millis, err := strconv.ParseInt(string(raw), 10, 64) + if err != nil || millis < 0 { + return 0 + } + return time.Duration(millis) * time.Millisecond +} diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 41922d8c..9b8248ca 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -133,9 +133,15 @@ func (d *DualWriter) Read(ctx context.Context, cmd string, args [][]byte) (any, // cmd must be the pre-uppercased command name. func (d *DualWriter) Blocking(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) + timeout := blockingCommandTimeout(cmd, args) start := time.Now() - result := d.primary.Do(ctx, iArgs...) + var result *redis.Cmd + if blockingBackend, ok := d.primary.(blockingTimeoutBackend); ok { + result = blockingBackend.DoWithTimeout(ctx, timeout, iArgs...) + } else { + result = d.primary.Do(ctx, iArgs...) + } resp, err := result.Result() d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 7e8efb3c..562feeb0 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -478,6 +478,114 @@ func TestDualWriter_Read_NoShadowInDualWrite(t *testing.T) { assert.Equal(t, 0, secondary.CallCount(), "no shadow in dual-write mode") } +type timeoutCapturingBackend struct { + name string + timeout time.Duration + args []any + doCalls int + doWithCalls int + returnValue any + returnErr error +} + +func (b *timeoutCapturingBackend) Do(ctx context.Context, args ...any) *redis.Cmd { + b.doCalls++ + b.args = append([]any(nil), args...) + cmd := redis.NewCmd(ctx, args...) + if b.returnErr != nil { + cmd.SetErr(b.returnErr) + return cmd + } + cmd.SetVal(b.returnValue) + return cmd +} + +func (b *timeoutCapturingBackend) DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd { + b.doWithCalls++ + b.timeout = timeout + b.args = append([]any(nil), args...) + cmd := redis.NewCmd(ctx, args...) + if b.returnErr != nil { + cmd.SetErr(b.returnErr) + return cmd + } + cmd.SetVal(b.returnValue) + return cmd +} + +func (b *timeoutCapturingBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) { + results := make([]*redis.Cmd, len(cmds)) + for i, args := range cmds { + results[i] = b.Do(ctx, args...) + } + return results, nil +} + +func (b *timeoutCapturingBackend) Close() error { return nil } +func (b *timeoutCapturingBackend) Name() string { return b.name } + +func TestBlockingCommandTimeout(t *testing.T) { + tests := []struct { + name string + cmd string + args [][]byte + expected time.Duration + }{ + { + name: "BZPOPMIN seconds", + cmd: "BZPOPMIN", + args: [][]byte{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")}, + expected: 5 * time.Second, + }, + { + name: "BLMOVE float seconds", + cmd: "BLMOVE", + args: [][]byte{[]byte("BLMOVE"), []byte("src"), []byte("dst"), []byte("LEFT"), []byte("RIGHT"), []byte("2.5")}, + expected: 2500 * time.Millisecond, + }, + { + name: "XREAD block milliseconds", + cmd: "XREAD", + args: [][]byte{[]byte("XREAD"), []byte("BLOCK"), []byte("1500"), []byte("STREAMS"), []byte("jobs"), []byte("0")}, + expected: 1500 * time.Millisecond, + }, + { + name: "XREADGROUP block zero", + cmd: "XREADGROUP", + args: [][]byte{[]byte("XREADGROUP"), []byte("GROUP"), []byte("g"), []byte("c"), []byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("jobs"), []byte(">")}, + expected: 0, + }, + { + name: "missing block falls back to zero", + cmd: "XREAD", + args: [][]byte{[]byte("XREAD"), []byte("STREAMS"), []byte("jobs"), []byte("0")}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, blockingCommandTimeout(tt.cmd, tt.args)) + }) + } +} + +func TestDualWriter_Blocking_UsesTimeoutAwareBackend(t *testing.T) { + primary := &timeoutCapturingBackend{name: "primary", returnValue: "OK"} + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly}, metrics, newTestSentry(), testLogger) + + resp, err := d.Blocking(context.Background(), "BZPOPMIN", [][]byte{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")}) + assert.NoError(t, err) + assert.Equal(t, "OK", resp) + assert.Equal(t, 0, primary.doCalls) + assert.Equal(t, 1, primary.doWithCalls) + assert.Equal(t, 5*time.Second, primary.timeout) + assert.Equal(t, []any{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")}, primary.args) +} + func TestDualWriter_GoAsync_Bounded(t *testing.T) { primary := newMockBackend("primary") primary.doFunc = makeCmd("OK", nil) From 7f3f9176176e6bf313a7354c162f97e63ae3dc20 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 20 Mar 2026 20:06:39 +0900 Subject: [PATCH 2/5] proxy: reduce secondary replay mismatches --- proxy/backend.go | 2 + proxy/dualwrite.go | 16 ++++-- proxy/proxy_test.go | 48 ++++++++++++++++++ proxy/script_cache.go | 114 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 proxy/script_cache.go diff --git a/proxy/backend.go b/proxy/backend.go index c61dec31..3a5a7aa9 100644 --- a/proxy/backend.go +++ b/proxy/backend.go @@ -14,6 +14,7 @@ const ( defaultDialTimeout = 5 * time.Second defaultReadTimeout = 3 * time.Second defaultWriteTimeout = 3 * time.Second + respProtocolV2 = 2 ) // Backend abstracts a Redis-protocol endpoint (real Redis or ElasticKV). @@ -72,6 +73,7 @@ func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) * Addr: addr, DB: opts.DB, Password: opts.Password, + Protocol: respProtocolV2, PoolSize: opts.PoolSize, DialTimeout: opts.DialTimeout, ReadTimeout: opts.ReadTimeout, diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 9b8248ca..bc919551 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -32,9 +32,11 @@ type DualWriter struct { writeSem chan struct{} // bounds concurrent secondary write goroutines shadowSem chan struct{} // bounds concurrent shadow read goroutines - wg sync.WaitGroup - mu sync.Mutex // protects closed; held briefly to make wg.Add atomic with close check - closed bool + wg sync.WaitGroup + mu sync.Mutex // protects closed; held briefly to make wg.Add atomic with close check + closed bool + scriptMu sync.RWMutex + scripts map[string]string } // NewDualWriter creates a DualWriter with the given backends. @@ -48,6 +50,7 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe logger: logger, writeSem: make(chan struct{}, maxWriteGoroutines), shadowSem: make(chan struct{}, maxShadowGoroutines), + scripts: make(map[string]string), } if cfg.Mode == ModeDualWriteShadow || cfg.Mode == ModeElasticKVPrimary { @@ -190,6 +193,7 @@ func (d *DualWriter) Script(ctx context.Context, cmd string, args [][]byte) (any result := d.primary.Do(ctx, iArgs...) resp, err := result.Result() d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + d.rememberScript(cmd, args) if err != nil && !errors.Is(err, redis.Nil) { d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() @@ -211,6 +215,12 @@ func (d *DualWriter) writeSecondary(cmd string, iArgs []any) { start := time.Now() result := d.secondary.Do(sCtx, iArgs...) _, sErr := result.Result() + if isNoScriptError(sErr) { + if fallbackArgs, ok := d.evalFallbackArgs(cmd, iArgs); ok { + result = d.secondary.Do(sCtx, fallbackArgs...) + _, sErr = result.Result() + } + } d.metrics.CommandDuration.WithLabelValues(cmd, d.secondary.Name()).Observe(time.Since(start).Seconds()) if sErr != nil && !errors.Is(sErr, redis.Nil) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 562feeb0..aa26fc5a 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -768,6 +768,15 @@ func TestDefaultBackendOptions(t *testing.T) { assert.Equal(t, 5*time.Second, opts.DialTimeout) } +func TestNewRedisBackend_UsesRESP2(t *testing.T) { + backend := NewRedisBackend("127.0.0.1:6379", "test") + t.Cleanup(func() { + assert.NoError(t, backend.Close()) + }) + + assert.Equal(t, respProtocolV2, backend.client.Options().Protocol) +} + // ========== Pipeline error handling tests ========== func TestPipeline_TransportError(t *testing.T) { @@ -781,6 +790,45 @@ func TestPipeline_TransportError(t *testing.T) { assert.Len(t, results, 3) } +func TestDualWriter_Script_CachesEvalForEvalSHAFallback(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + + secondary := newMockBackend("secondary") + script := "return ARGV[1]" + sha := scriptSHA(script) + var calls int + secondary.doFunc = func(ctx context.Context, args ...any) *redis.Cmd { + calls++ + cmd := redis.NewCmd(ctx, args...) + switch calls { + case 1: + assert.Equal(t, []byte("EVALSHA"), args[0]) + assert.Equal(t, []byte(sha), args[1]) + cmd.SetErr(testRedisErr("NOSCRIPT No matching script. Please use EVAL.")) + case 2: + assert.Equal(t, []byte("EVAL"), args[0]) + assert.Equal(t, []byte(script), args[1]) + cmd.SetVal("OK") + default: + t.Fatalf("unexpected secondary call %d", calls) + } + return cmd + } + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + _, err := d.Script(context.Background(), "EVAL", [][]byte{[]byte("EVAL"), []byte(script), []byte("0"), []byte("value")}) + assert.NoError(t, err) + + d.cfg.Mode = ModeDualWrite + d.writeSecondary("EVALSHA", []any{[]byte("EVALSHA"), []byte(sha), []byte("0"), []byte("value")}) + + assert.Equal(t, 2, calls) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001) +} + // ========== writeRedisValue tests ========== // testRedisErr satisfies the redis.Error interface for testing. diff --git a/proxy/script_cache.go b/proxy/script_cache.go new file mode 100644 index 00000000..6dd11dce --- /dev/null +++ b/proxy/script_cache.go @@ -0,0 +1,114 @@ +package proxy + +import ( + "crypto/sha1" // #nosec G505 -- Redis EVALSHA specifies SHA1 script digests. + "encoding/hex" + "errors" + "strings" + + "github.com/redis/go-redis/v9" +) + +const ( + cmdEval = "EVAL" + cmdEvalSHA = "EVALSHA" + cmdScript = "SCRIPT" + + minScriptSubcommandArgs = 2 + scriptLoadArgIndex = 2 + minEvalSHAArgs = 2 +) + +func (d *DualWriter) rememberScript(cmd string, args [][]byte) { + upper := strings.ToUpper(cmd) + + switch upper { + case cmdEval, "EVAL_RO": + if len(args) > 1 { + d.storeScript(string(args[1])) + } + case cmdScript: + if len(args) < minScriptSubcommandArgs { + return + } + switch strings.ToUpper(string(args[1])) { + case "LOAD": + if len(args) > scriptLoadArgIndex { + d.storeScript(string(args[scriptLoadArgIndex])) + } + case "FLUSH": + d.clearScripts() + } + } +} + +func (d *DualWriter) storeScript(script string) { + sha := scriptSHA(script) + + d.scriptMu.Lock() + defer d.scriptMu.Unlock() + d.scripts[sha] = script +} + +func (d *DualWriter) clearScripts() { + d.scriptMu.Lock() + defer d.scriptMu.Unlock() + clear(d.scripts) +} + +func (d *DualWriter) lookupScript(sha string) (string, bool) { + d.scriptMu.RLock() + defer d.scriptMu.RUnlock() + script, ok := d.scripts[strings.ToLower(sha)] + return script, ok +} + +func (d *DualWriter) evalFallbackArgs(cmd string, iArgs []any) ([]any, bool) { + upper := strings.ToUpper(cmd) + if upper != cmdEvalSHA && upper != "EVALSHA_RO" { + return nil, false + } + if len(iArgs) < minEvalSHAArgs { + return nil, false + } + + sha := stringArg(iArgs[1]) + script, ok := d.lookupScript(sha) + if !ok { + return nil, false + } + + fallback := make([]any, len(iArgs)) + fallback[0] = []byte(cmdEval) + fallback[1] = []byte(script) + copy(fallback[2:], iArgs[2:]) + return fallback, true +} + +func isNoScriptError(err error) bool { + if err == nil { + return false + } + var redisErr redis.Error + if errors.As(err, &redisErr) { + return strings.HasPrefix(redisErr.Error(), "NOSCRIPT ") + } + return strings.HasPrefix(err.Error(), "NOSCRIPT ") +} + +func scriptSHA(script string) string { + // #nosec G401 -- Redis EVALSHA uses SHA1 digests by protocol. + sum := sha1.Sum([]byte(script)) + return hex.EncodeToString(sum[:]) +} + +func stringArg(arg any) string { + switch v := arg.(type) { + case []byte: + return strings.ToLower(string(v)) + case string: + return strings.ToLower(v) + default: + return strings.ToLower(string(argsToBytes([]any{arg})[0])) + } +} From 969c7b7d58479334fc1c1cf422d08ccc9826ebd7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 11:39:05 +0000 Subject: [PATCH 3/5] Initial plan From e8d532ceee9fef374e981ace4dd027d092d71dc3 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 20 Mar 2026 20:47:34 +0900 Subject: [PATCH 4/5] Raise gRPC message size limits --- adapter/test_util.go | 8 +++----- cmd/server/demo.go | 9 +++------ internal/grpc.go | 29 +++++++++++++++++++++++++++++ kv/grpc_conn_cache.go | 8 ++++---- main.go | 2 +- multiraft_runtime.go | 7 ++----- 6 files changed, 42 insertions(+), 21 deletions(-) create mode 100644 internal/grpc.go diff --git a/adapter/test_util.go b/adapter/test_util.go index 3eb3adf1..b4d41845 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -13,6 +13,7 @@ import ( "github.com/Jille/raft-grpc-leader-rpc/leaderhealth" transport "github.com/Jille/raft-grpc-transport" "github.com/Jille/raftadmin" + internalutil "github.com/bootjp/elastickv/internal" "github.com/bootjp/elastickv/kv" pb "github.com/bootjp/elastickv/proto" "github.com/bootjp/elastickv/store" @@ -23,7 +24,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sys/unix" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) func shutdown(nodes []Node) { @@ -350,7 +350,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( r, tm, err := newRaft(strconv.Itoa(i), port.raftAddress, fsm, i == 0, cfg, electionTimeout) assert.NoError(t, err) - s := grpc.NewServer() + s := grpc.NewServer(internalutil.GRPCServerOptions()...) trx := kv.NewTransaction(r) coordinator := kv.NewCoordinator(trx, r) relay := NewRedisPubSubRelay() @@ -416,9 +416,7 @@ func newRaft(myID string, myAddress string, fsm raft.FSM, bootstrap bool, cfg ra Level: hclog.LevelFromString("WARN"), }) - tm := transport.New(raft.ServerAddress(myAddress), []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - }) + tm := transport.New(raft.ServerAddress(myAddress), internalutil.GRPCDialOptions()) r, err := raft.NewRaft(c, fsm, ldb, sdb, fss, tm.Transport()) if err != nil { diff --git a/cmd/server/demo.go b/cmd/server/demo.go index 2853bf61..8daba62c 100644 --- a/cmd/server/demo.go +++ b/cmd/server/demo.go @@ -30,7 +30,6 @@ import ( "github.com/hashicorp/raft" "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) var ( @@ -206,7 +205,7 @@ func joinCluster(ctx context.Context, nodes []config) error { } // Connect to leader - conn, err := grpc.NewClient(leader.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.NewClient(leader.address, internalutil.GRPCDialOptions()...) if err != nil { return fmt.Errorf("failed to dial leader: %w", err) } @@ -322,7 +321,7 @@ func setupStorage(dir string) (raft.LogStore, raft.StableStore, raft.SnapshotSto } func setupGRPC(r *raft.Raft, st store.MVCCStore, tm *transport.Manager, coordinator *kv.Coordinate, distServer *adapter.DistributionServer, relay *adapter.RedisPubSubRelay) (*grpc.Server, *adapter.GRPCServer) { - s := grpc.NewServer() + s := grpc.NewServer(internalutil.GRPCServerOptions()...) trx := kv.NewTransaction(r) routedStore := kv.NewLeaderRoutedStore(st, coordinator) gs := adapter.NewGRPCServer(routedStore, coordinator, adapter.WithCloseStore()) @@ -381,9 +380,7 @@ func run(ctx context.Context, eg *errgroup.Group, cfg config) error { }) // Transport - tm := transport.New(raft.ServerAddress(cfg.address), []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - }) + tm := transport.New(raft.ServerAddress(cfg.address), internalutil.GRPCDialOptions()) r, err := raft.NewRaft(c, fsm, ldb, sdb, fss, tm.Transport()) if err != nil { diff --git a/internal/grpc.go b/internal/grpc.go new file mode 100644 index 00000000..59092615 --- /dev/null +++ b/internal/grpc.go @@ -0,0 +1,29 @@ +package internal + +import ( + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const GRPCMaxMessageBytes = 64 << 20 + +// GRPCServerOptions keeps Raft replication and the public/internal APIs aligned +// on the same message-size budget. +func GRPCServerOptions() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.MaxRecvMsgSize(GRPCMaxMessageBytes), + grpc.MaxSendMsgSize(GRPCMaxMessageBytes), + } +} + +// GRPCDialOptions returns the common insecure dial options used by node-local +// and node-to-node traffic. +func GRPCDialOptions() []grpc.DialOption { + return []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(GRPCMaxMessageBytes), + grpc.MaxCallSendMsgSize(GRPCMaxMessageBytes), + ), + } +} diff --git a/kv/grpc_conn_cache.go b/kv/grpc_conn_cache.go index fbca3d07..819806f4 100644 --- a/kv/grpc_conn_cache.go +++ b/kv/grpc_conn_cache.go @@ -3,11 +3,11 @@ package kv import ( "sync" + internalutil "github.com/bootjp/elastickv/internal" "github.com/cockroachdb/errors" "github.com/hashicorp/raft" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials/insecure" ) // GRPCConnCache reuses gRPC connections per address. gRPC itself handles @@ -75,9 +75,9 @@ func (c *GRPCConnCache) ConnFor(addr raft.ServerAddress) (*grpc.ClientConn, erro return conn, nil } - conn, err := grpc.NewClient(string(addr), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + conn, err := grpc.NewClient( + string(addr), + append(internalutil.GRPCDialOptions(), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)))..., ) if err != nil { return nil, errors.WithStack(err) diff --git a/main.go b/main.go index ffb35867..02f6d361 100644 --- a/main.go +++ b/main.go @@ -287,7 +287,7 @@ func raftMonitorRuntimes(runtimes []*raftGroupRuntime) []monitoring.RaftRuntime func startRaftServers(ctx context.Context, lc *net.ListenConfig, eg *errgroup.Group, runtimes []*raftGroupRuntime, shardStore *kv.ShardStore, coordinate kv.Coordinator, distServer *adapter.DistributionServer, relay *adapter.RedisPubSubRelay) error { for _, rt := range runtimes { - gs := grpc.NewServer() + gs := grpc.NewServer(internalutil.GRPCServerOptions()...) trx := kv.NewTransaction(rt.raft) grpcSvc := adapter.NewGRPCServer(shardStore, coordinate) pb.RegisterRawKVServer(gs, grpcSvc) diff --git a/multiraft_runtime.go b/multiraft_runtime.go index 22c5de14..d62cb9f5 100644 --- a/multiraft_runtime.go +++ b/multiraft_runtime.go @@ -7,12 +7,11 @@ import ( "time" transport "github.com/Jille/raft-grpc-transport" + internalutil "github.com/bootjp/elastickv/internal" "github.com/bootjp/elastickv/internal/raftstore" "github.com/bootjp/elastickv/store" "github.com/cockroachdb/errors" "github.com/hashicorp/raft" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) type raftGroupRuntime struct { @@ -117,9 +116,7 @@ func newRaftGroup(raftID string, group groupSpec, baseDir string, multi bool, bo return nil, nil, nil, errors.WithStack(err) } - tm = transport.New(raft.ServerAddress(group.address), []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - }) + tm = transport.New(raft.ServerAddress(group.address), internalutil.GRPCDialOptions()) r, err := raft.NewRaft(c, fsm, raftStore, raftStore, fss, tm.Transport()) if err != nil { From f26bcd3a62079d4f003e588bf0c51fdc1d46ac5e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 11:48:09 +0000 Subject: [PATCH 5/5] proxy: fix EVALSHA_RO fallback, guard script cache on primary error, add bounded eviction Co-authored-by: bootjp <1306365+bootjp@users.noreply.github.com> --- go.mod | 2 +- proxy/dualwrite.go | 4 +- proxy/proxy_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++ proxy/script_cache.go | 42 ++++++++++++++--- 4 files changed, 142 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 5e024a17..47457faa 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/emirpasic/gods v1.18.1 github.com/getsentry/sentry-go v0.27.0 github.com/hashicorp/go-hclog v1.6.3 + github.com/hashicorp/go-msgpack/v2 v2.1.2 github.com/hashicorp/raft v1.7.3 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.2 @@ -66,7 +67,6 @@ require ( github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-metrics v0.5.4 // indirect - github.com/hashicorp/go-msgpack/v2 v2.1.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index bc919551..9da223c7 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -37,6 +37,8 @@ type DualWriter struct { closed bool scriptMu sync.RWMutex scripts map[string]string + // scriptOrder tracks insertion order for FIFO eviction of the bounded script cache. + scriptOrder []string } // NewDualWriter creates a DualWriter with the given backends. @@ -193,13 +195,13 @@ func (d *DualWriter) Script(ctx context.Context, cmd string, args [][]byte) (any result := d.primary.Do(ctx, iArgs...) resp, err := result.Result() d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) - d.rememberScript(cmd, args) if err != nil && !errors.Is(err, redis.Nil) { d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() return nil, fmt.Errorf("primary script %s: %w", cmd, err) } d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + d.rememberScript(cmd, args) if d.hasSecondaryWrite() { d.goWrite(func() { d.writeSecondary(cmd, iArgs) }) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index aa26fc5a..87dce0e8 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -3,6 +3,7 @@ package proxy import ( "context" "errors" + "fmt" "io" "log/slog" "sync" @@ -829,6 +830,107 @@ func TestDualWriter_Script_CachesEvalForEvalSHAFallback(t *testing.T) { assert.InDelta(t, 0, testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001) } +func TestDualWriter_Script_EvalSHARO_FallsBackToEvalRO(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + + secondary := newMockBackend("secondary") + script := "return KEYS[1]" + sha := scriptSHA(script) + var calls int + secondary.doFunc = func(ctx context.Context, args ...any) *redis.Cmd { + calls++ + cmd := redis.NewCmd(ctx, args...) + switch calls { + case 1: + assert.Equal(t, []byte("EVALSHA_RO"), args[0]) + cmd.SetErr(testRedisErr("NOSCRIPT No matching script. Please use EVAL.")) + case 2: + // Must fall back to EVAL_RO, not EVAL, to preserve read-only semantics. + assert.Equal(t, []byte("EVAL_RO"), args[0]) + assert.Equal(t, []byte(script), args[1]) + cmd.SetVal("mykey") + default: + t.Fatalf("unexpected secondary call %d", calls) + } + return cmd + } + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + // Register the script body via EVAL_RO so the proxy can fall back. + _, err := d.Script(context.Background(), "EVAL_RO", [][]byte{[]byte("EVAL_RO"), []byte(script), []byte("1"), []byte("mykey")}) + assert.NoError(t, err) + + d.cfg.Mode = ModeDualWrite + d.writeSecondary("EVALSHA_RO", []any{[]byte("EVALSHA_RO"), []byte(sha), []byte("1"), []byte("mykey")}) + + assert.Equal(t, 2, calls) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001) +} + +func TestDualWriter_Script_NoRememberOnPrimaryError(t *testing.T) { + // Verify that a failed SCRIPT FLUSH on the primary does NOT clear the proxy + // script cache, so that subsequent EVALSHA → EVAL fallbacks still work. + primary := newMockBackend("primary") + primary.doFunc = makeCmd(nil, testRedisErr("ERR flush failed")) + + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("OK", nil) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + // Seed the script cache directly. + script := "return 1" + d.storeScript(script) + sha := scriptSHA(script) + _, cached := d.lookupScript(sha) + assert.True(t, cached, "script should be cached before flush attempt") + + // Attempt SCRIPT FLUSH — primary returns an error so the cache must be untouched. + _, err := d.Script(context.Background(), "SCRIPT", [][]byte{[]byte("SCRIPT"), []byte("FLUSH")}) + assert.Error(t, err) + + _, stillCached := d.lookupScript(sha) + assert.True(t, stillCached, "cache must not be cleared when primary SCRIPT FLUSH fails") +} + +func TestDualWriter_ScriptCache_BoundedEviction(t *testing.T) { + // Fill the cache beyond maxScriptCacheSize and verify it stays bounded. + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + + metrics := newTestMetrics() + d := NewDualWriter(primary, nil, ProxyConfig{Mode: ModeRedisOnly, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + // Insert maxScriptCacheSize+10 unique scripts. + total := maxScriptCacheSize + 10 + for i := range total { + d.storeScript(fmt.Sprintf("return %d", i)) + } + + d.scriptMu.RLock() + size := len(d.scripts) + d.scriptMu.RUnlock() + + assert.Equal(t, maxScriptCacheSize, size, "cache must not exceed maxScriptCacheSize") + + // The first 10 scripts (insertion order) must have been evicted. + for i := range 10 { + sha := scriptSHA(fmt.Sprintf("return %d", i)) + _, ok := d.lookupScript(sha) + assert.False(t, ok, "script %d should have been evicted", i) + } + // The last maxScriptCacheSize scripts must still be present. + for i := 10; i < total; i++ { + sha := scriptSHA(fmt.Sprintf("return %d", i)) + _, ok := d.lookupScript(sha) + assert.True(t, ok, "script %d should still be cached", i) + } +} + // ========== writeRedisValue tests ========== // testRedisErr satisfies the redis.Error interface for testing. diff --git a/proxy/script_cache.go b/proxy/script_cache.go index 6dd11dce..96d1be03 100644 --- a/proxy/script_cache.go +++ b/proxy/script_cache.go @@ -10,20 +10,27 @@ import ( ) const ( - cmdEval = "EVAL" - cmdEvalSHA = "EVALSHA" - cmdScript = "SCRIPT" + cmdEval = "EVAL" + cmdEvalRO = "EVAL_RO" + cmdEvalSHA = "EVALSHA" + cmdEvalSHARO = "EVALSHA_RO" + cmdScript = "SCRIPT" minScriptSubcommandArgs = 2 scriptLoadArgIndex = 2 minEvalSHAArgs = 2 + + // maxScriptCacheSize is the maximum number of scripts retained in the + // proxy-side script cache. When the limit is reached, the oldest entry + // (by insertion order) is evicted to prevent unbounded memory growth. + maxScriptCacheSize = 512 ) func (d *DualWriter) rememberScript(cmd string, args [][]byte) { upper := strings.ToUpper(cmd) switch upper { - case cmdEval, "EVAL_RO": + case cmdEval, cmdEvalRO: if len(args) > 1 { d.storeScript(string(args[1])) } @@ -47,6 +54,22 @@ func (d *DualWriter) storeScript(script string) { d.scriptMu.Lock() defer d.scriptMu.Unlock() + + if _, exists := d.scripts[sha]; !exists { + // Evict the oldest entry when at capacity. + // scriptOrder and scripts always have the same length, so the guard is + // sufficient; the additional len(scriptOrder) check is not needed. + if len(d.scripts) >= maxScriptCacheSize { + oldest := d.scriptOrder[0] + d.scriptOrder = d.scriptOrder[1:] + delete(d.scripts, oldest) + } + // Append to the end so that eviction follows strict FIFO insertion order. + // Re-storing an already-cached script does not move its position; the + // entry retains its original eviction priority, which is acceptable for + // this use-case because script bodies are immutable (same SHA ⇒ same body). + d.scriptOrder = append(d.scriptOrder, sha) + } d.scripts[sha] = script } @@ -54,6 +77,7 @@ func (d *DualWriter) clearScripts() { d.scriptMu.Lock() defer d.scriptMu.Unlock() clear(d.scripts) + d.scriptOrder = d.scriptOrder[:0] } func (d *DualWriter) lookupScript(sha string) (string, bool) { @@ -65,7 +89,7 @@ func (d *DualWriter) lookupScript(sha string) (string, bool) { func (d *DualWriter) evalFallbackArgs(cmd string, iArgs []any) ([]any, bool) { upper := strings.ToUpper(cmd) - if upper != cmdEvalSHA && upper != "EVALSHA_RO" { + if upper != cmdEvalSHA && upper != cmdEvalSHARO { return nil, false } if len(iArgs) < minEvalSHAArgs { @@ -78,8 +102,14 @@ func (d *DualWriter) evalFallbackArgs(cmd string, iArgs []any) ([]any, bool) { return nil, false } + // Preserve the read-only semantics: EVALSHA_RO falls back to EVAL_RO. + fallbackCmd := cmdEval + if upper == cmdEvalSHARO { + fallbackCmd = cmdEvalRO + } + fallback := make([]any, len(iArgs)) - fallback[0] = []byte(cmdEval) + fallback[0] = []byte(fallbackCmd) fallback[1] = []byte(script) copy(fallback[2:], iArgs[2:]) return fallback, true