Skip to content

Commit b751fd1

Browse files
committed
proxy: honor blocking command timeouts
1 parent 148f5b3 commit b751fd1

File tree

4 files changed

+176
-1
lines changed

4 files changed

+176
-1
lines changed

proxy/backend.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ func (b *RedisBackend) Do(ctx context.Context, args ...any) *redis.Cmd {
8585
return b.client.Do(ctx, args...)
8686
}
8787

88+
// DoWithTimeout executes a command using a per-call socket timeout override.
89+
// This is used for blocking commands whose wait time exceeds the backend's
90+
// default read timeout.
91+
func (b *RedisBackend) DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd {
92+
return b.client.WithTimeout(timeout).Do(ctx, args...)
93+
}
94+
8895
func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) {
8996
pipe := b.client.Pipeline()
9097
results := make([]*redis.Cmd, len(cmds))

proxy/blocking.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"strconv"
6+
"strings"
7+
"time"
8+
9+
"github.com/redis/go-redis/v9"
10+
)
11+
12+
const blockingMultiPopMinArgs = 2
13+
14+
type blockingTimeoutBackend interface {
15+
DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd
16+
}
17+
18+
func blockingCommandTimeout(cmd string, args [][]byte) time.Duration {
19+
switch strings.ToUpper(cmd) {
20+
case "BLPOP", "BRPOP", "BRPOPLPUSH", "BLMOVE", "BZPOPMIN", "BZPOPMAX":
21+
if len(args) == 0 {
22+
return 0
23+
}
24+
return parseBlockingSecondsArg(args[len(args)-1])
25+
case "BLMPOP":
26+
if len(args) < blockingMultiPopMinArgs {
27+
return 0
28+
}
29+
return parseBlockingSecondsArg(args[1])
30+
case "XREAD", "XREADGROUP":
31+
for i := 1; i+1 < len(args); i++ {
32+
if strings.EqualFold(string(args[i]), "BLOCK") {
33+
return parseBlockingMillisecondsArg(args[i+1])
34+
}
35+
}
36+
}
37+
return 0
38+
}
39+
40+
func parseBlockingSecondsArg(raw []byte) time.Duration {
41+
seconds, err := strconv.ParseFloat(string(raw), 64)
42+
if err != nil || seconds < 0 {
43+
return 0
44+
}
45+
return time.Duration(seconds * float64(time.Second))
46+
}
47+
48+
func parseBlockingMillisecondsArg(raw []byte) time.Duration {
49+
millis, err := strconv.ParseInt(string(raw), 10, 64)
50+
if err != nil || millis < 0 {
51+
return 0
52+
}
53+
return time.Duration(millis) * time.Millisecond
54+
}

proxy/dualwrite.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,15 @@ func (d *DualWriter) Read(ctx context.Context, cmd string, args [][]byte) (any,
133133
// cmd must be the pre-uppercased command name.
134134
func (d *DualWriter) Blocking(ctx context.Context, cmd string, args [][]byte) (any, error) {
135135
iArgs := bytesArgsToInterfaces(args)
136+
timeout := blockingCommandTimeout(cmd, args)
136137

137138
start := time.Now()
138-
result := d.primary.Do(ctx, iArgs...)
139+
var result *redis.Cmd
140+
if blockingBackend, ok := d.primary.(blockingTimeoutBackend); ok {
141+
result = blockingBackend.DoWithTimeout(ctx, timeout, iArgs...)
142+
} else {
143+
result = d.primary.Do(ctx, iArgs...)
144+
}
139145
resp, err := result.Result()
140146
d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
141147

proxy/proxy_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,114 @@ func TestDualWriter_Read_NoShadowInDualWrite(t *testing.T) {
478478
assert.Equal(t, 0, secondary.CallCount(), "no shadow in dual-write mode")
479479
}
480480

481+
type timeoutCapturingBackend struct {
482+
name string
483+
timeout time.Duration
484+
args []any
485+
doCalls int
486+
doWithCalls int
487+
returnValue any
488+
returnErr error
489+
}
490+
491+
func (b *timeoutCapturingBackend) Do(ctx context.Context, args ...any) *redis.Cmd {
492+
b.doCalls++
493+
b.args = append([]any(nil), args...)
494+
cmd := redis.NewCmd(ctx, args...)
495+
if b.returnErr != nil {
496+
cmd.SetErr(b.returnErr)
497+
return cmd
498+
}
499+
cmd.SetVal(b.returnValue)
500+
return cmd
501+
}
502+
503+
func (b *timeoutCapturingBackend) DoWithTimeout(ctx context.Context, timeout time.Duration, args ...any) *redis.Cmd {
504+
b.doWithCalls++
505+
b.timeout = timeout
506+
b.args = append([]any(nil), args...)
507+
cmd := redis.NewCmd(ctx, args...)
508+
if b.returnErr != nil {
509+
cmd.SetErr(b.returnErr)
510+
return cmd
511+
}
512+
cmd.SetVal(b.returnValue)
513+
return cmd
514+
}
515+
516+
func (b *timeoutCapturingBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) {
517+
results := make([]*redis.Cmd, len(cmds))
518+
for i, args := range cmds {
519+
results[i] = b.Do(ctx, args...)
520+
}
521+
return results, nil
522+
}
523+
524+
func (b *timeoutCapturingBackend) Close() error { return nil }
525+
func (b *timeoutCapturingBackend) Name() string { return b.name }
526+
527+
func TestBlockingCommandTimeout(t *testing.T) {
528+
tests := []struct {
529+
name string
530+
cmd string
531+
args [][]byte
532+
expected time.Duration
533+
}{
534+
{
535+
name: "BZPOPMIN seconds",
536+
cmd: "BZPOPMIN",
537+
args: [][]byte{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")},
538+
expected: 5 * time.Second,
539+
},
540+
{
541+
name: "BLMOVE float seconds",
542+
cmd: "BLMOVE",
543+
args: [][]byte{[]byte("BLMOVE"), []byte("src"), []byte("dst"), []byte("LEFT"), []byte("RIGHT"), []byte("2.5")},
544+
expected: 2500 * time.Millisecond,
545+
},
546+
{
547+
name: "XREAD block milliseconds",
548+
cmd: "XREAD",
549+
args: [][]byte{[]byte("XREAD"), []byte("BLOCK"), []byte("1500"), []byte("STREAMS"), []byte("jobs"), []byte("0")},
550+
expected: 1500 * time.Millisecond,
551+
},
552+
{
553+
name: "XREADGROUP block zero",
554+
cmd: "XREADGROUP",
555+
args: [][]byte{[]byte("XREADGROUP"), []byte("GROUP"), []byte("g"), []byte("c"), []byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("jobs"), []byte(">")},
556+
expected: 0,
557+
},
558+
{
559+
name: "missing block falls back to zero",
560+
cmd: "XREAD",
561+
args: [][]byte{[]byte("XREAD"), []byte("STREAMS"), []byte("jobs"), []byte("0")},
562+
expected: 0,
563+
},
564+
}
565+
566+
for _, tt := range tests {
567+
t.Run(tt.name, func(t *testing.T) {
568+
assert.Equal(t, tt.expected, blockingCommandTimeout(tt.cmd, tt.args))
569+
})
570+
}
571+
}
572+
573+
func TestDualWriter_Blocking_UsesTimeoutAwareBackend(t *testing.T) {
574+
primary := &timeoutCapturingBackend{name: "primary", returnValue: "OK"}
575+
secondary := newMockBackend("secondary")
576+
577+
metrics := newTestMetrics()
578+
d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly}, metrics, newTestSentry(), testLogger)
579+
580+
resp, err := d.Blocking(context.Background(), "BZPOPMIN", [][]byte{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")})
581+
assert.NoError(t, err)
582+
assert.Equal(t, "OK", resp)
583+
assert.Equal(t, 0, primary.doCalls)
584+
assert.Equal(t, 1, primary.doWithCalls)
585+
assert.Equal(t, 5*time.Second, primary.timeout)
586+
assert.Equal(t, []any{[]byte("BZPOPMIN"), []byte("queue"), []byte("5")}, primary.args)
587+
}
588+
481589
func TestDualWriter_GoAsync_Bounded(t *testing.T) {
482590
primary := newMockBackend("primary")
483591
primary.doFunc = makeCmd("OK", nil)

0 commit comments

Comments
 (0)