Skip to content

Commit 6dd2c2b

Browse files
committed
fix: make inmem bookmarks random for each run
As inmem storage is not persisted, make sure watch bookmark from one run doesn't work with another run. This still allows watches to be restarted on connection failures, but if the watch is restarted on a program restart, bookmark won't match anymore. Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
1 parent f4ff7ab commit 6dd2c2b

File tree

9 files changed

+162
-11
lines changed

9 files changed

+162
-11
lines changed

pkg/state/errors.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,15 @@ func errPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error
137137
},
138138
}
139139
}
140+
141+
// ErrInvalidWatchBookmark should be implemented by "invalid watch bookmark" errors.
142+
type ErrInvalidWatchBookmark interface {
143+
InvalidWatchBookmarkError()
144+
}
145+
146+
// IsInvalidWatchBookmarkError checks if err is invalid watch bookmark.
147+
func IsInvalidWatchBookmarkError(err error) bool {
148+
var i ErrInvalidWatchBookmark
149+
150+
return errors.As(err, &i)
151+
}

pkg/state/impl/inmem/collection.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ package inmem
66

77
import (
88
"context"
9+
"crypto/rand"
910
"encoding/binary"
1011
"fmt"
12+
"io"
1113
"slices"
1214
"sort"
1315
"sync"
@@ -265,16 +267,34 @@ func (collection *ResourceCollection) Destroy(ctx context.Context, ptr resource.
265267
return nil
266268
}
267269

270+
// bookmarkCookie is a random cookie used to encode bookmarks.
271+
//
272+
// As the state is in-memory, we need to distinguish between bookmarks from different runs of the program.
273+
var bookmarkCookie = sync.OnceValue(func() []byte {
274+
cookie := make([]byte, 8)
275+
276+
_, err := io.ReadFull(rand.Reader, cookie)
277+
if err != nil {
278+
panic(err)
279+
}
280+
281+
return cookie
282+
})
283+
268284
func encodeBookmark(pos int64) state.Bookmark {
269-
return binary.BigEndian.AppendUint64(nil, uint64(pos))
285+
return binary.BigEndian.AppendUint64(slices.Clone(bookmarkCookie()), uint64(pos))
270286
}
271287

272288
func decodeBookmark(bookmark state.Bookmark) (int64, error) {
273-
if len(bookmark) != 8 {
274-
return 0, fmt.Errorf("invalid bookmark length: %d", len(bookmark))
289+
if len(bookmark) != 16 {
290+
return 0, ErrInvalidWatchBookmark
291+
}
292+
293+
if !slices.Equal(bookmark[:8], bookmarkCookie()) {
294+
return 0, ErrInvalidWatchBookmark
275295
}
276296

277-
return int64(binary.BigEndian.Uint64(bookmark)), nil
297+
return int64(binary.BigEndian.Uint64(bookmark[8:])), nil
278298
}
279299

280300
// Watch for specific resource changes.
@@ -321,7 +341,7 @@ func (collection *ResourceCollection) Watch(ctx context.Context, id resource.ID,
321341
}
322342

323343
if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < 0 || pos >= collection.writePos {
324-
return fmt.Errorf("invalid bookmark: %d", pos)
344+
return ErrInvalidWatchBookmark
325345
}
326346

327347
// skip the bookmarked event
@@ -478,7 +498,7 @@ func (collection *ResourceCollection) WatchAll(ctx context.Context, singleCh cha
478498
}
479499

480500
if pos < collection.writePos-int64(collection.capacity)+int64(collection.gap) || pos < -1 || pos >= collection.writePos {
481-
return fmt.Errorf("invalid bookmark: %d", pos)
501+
return ErrInvalidWatchBookmark
482502
}
483503

484504
// skip the bookmarked event

pkg/state/impl/inmem/errors.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package inmem
66

77
import (
8+
"errors"
89
"fmt"
910

1011
"github.com/cosi-project/runtime/pkg/resource"
@@ -100,3 +101,15 @@ func ErrPhaseConflict(r resource.Reference, expectedPhase resource.Phase) error
100101
},
101102
}
102103
}
104+
105+
//nolint:errname
106+
type eInvalidWatchBookmark struct {
107+
error
108+
}
109+
110+
func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {}
111+
112+
// ErrInvalidWatchBookmark generates error compatible with state.ErrInvalidWatchBookmark.
113+
var ErrInvalidWatchBookmark = eInvalidWatchBookmark{
114+
errors.New("invalid watch bookmark"),
115+
}

pkg/state/impl/inmem/errors_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,6 @@ func TestErrors(t *testing.T) {
3232

3333
assert.True(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("a"), state.WithResourceNamespace("ns")))
3434
assert.False(t, state.IsConflictError(inmem.ErrAlreadyExists(resource.NewMetadata("ns", "a", "b", resource.VersionUndefined)), state.WithResourceType("z"), state.WithResourceNamespace("ns")))
35+
36+
assert.True(t, state.IsInvalidWatchBookmarkError(inmem.ErrInvalidWatchBookmark))
3537
}

pkg/state/impl/inmem/local_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package inmem_test
77
import (
88
"context"
99
"fmt"
10+
"slices"
1011
"strconv"
1112
"testing"
1213
"time"
@@ -199,3 +200,36 @@ func TestNoBufferOverrunDynamic(t *testing.T) {
199200
}
200201
}
201202
}
203+
204+
func TestWatchInvalidBookmark(t *testing.T) {
205+
t.Parallel()
206+
207+
const namespace = "default"
208+
209+
st := state.WrapCore(inmem.NewState(namespace))
210+
211+
ctx, cancel := context.WithCancel(context.Background())
212+
t.Cleanup(cancel)
213+
214+
// start watching for changes
215+
watchKindCh := make(chan state.Event)
216+
217+
err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh)
218+
require.NoError(t, err)
219+
220+
// insert resource
221+
err = st.Create(ctx, conformance.NewPathResource(namespace, "0"))
222+
require.NoError(t, err)
223+
224+
ev := <-watchKindCh
225+
226+
require.Equal(t, state.Created, ev.Type)
227+
require.NotEmpty(t, ev.Bookmark)
228+
229+
invalidBookmark := slices.Clone(ev.Bookmark)
230+
invalidBookmark[0] ^= 0xff
231+
232+
err = st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh, state.WithKindStartFromBookmark(invalidBookmark))
233+
require.Error(t, err)
234+
require.True(t, state.IsInvalidWatchBookmarkError(err))
235+
}

pkg/state/protobuf/client/client.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,12 @@ func (adapter *Adapter) Watch(ctx context.Context, resourcePointer resource.Poin
340340
// receive first (empty) watch event
341341
_, err = cli.Recv()
342342
if err != nil {
343-
return err
343+
switch status.Code(err) { //nolint:exhaustive
344+
case codes.FailedPrecondition:
345+
return eInvalidWatchBookmark{err}
346+
default:
347+
return err
348+
}
344349
}
345350

346351
go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
@@ -388,7 +393,12 @@ func (adapter *Adapter) WatchKind(ctx context.Context, resourceKind resource.Kin
388393
// receive first (empty) watch event
389394
_, err = cli.Recv()
390395
if err != nil {
391-
return err
396+
switch status.Code(err) { //nolint:exhaustive
397+
case codes.FailedPrecondition:
398+
return eInvalidWatchBookmark{err}
399+
default:
400+
return err
401+
}
392402
}
393403

394404
go adapter.watchAdapter(ctx, cli, ch, nil, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
@@ -437,7 +447,12 @@ func (adapter *Adapter) WatchKindAggregated(ctx context.Context, resourceKind re
437447
// receive first (empty) watch event
438448
_, err = cli.Recv()
439449
if err != nil {
440-
return err
450+
switch status.Code(err) { //nolint:exhaustive
451+
case codes.FailedPrecondition:
452+
return eInvalidWatchBookmark{err}
453+
default:
454+
return err
455+
}
441456
}
442457

443458
go adapter.watchAdapter(ctx, cli, nil, ch, opts.UnmarshalOptions.SkipProtobufUnmarshal, req)
@@ -526,7 +541,12 @@ func (adapter *Adapter) watchAdapter(
526541

527542
_, err = cli.Recv()
528543
if err != nil {
529-
continue
544+
switch status.Code(err) { //nolint:exhaustive
545+
case codes.FailedPrecondition: // abort retries on invalid watch bookmark
546+
return nil, eInvalidWatchBookmark{err}
547+
default:
548+
continue
549+
}
530550
}
531551

532552
msg, err = cli.Recv()

pkg/state/protobuf/client/errors.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,10 @@ type ePhaseConflict struct {
3838
}
3939

4040
func (ePhaseConflict) PhaseConflictError() {}
41+
42+
//nolint:errname
43+
type eInvalidWatchBookmark struct {
44+
error
45+
}
46+
47+
func (eInvalidWatchBookmark) InvalidWatchBookmarkError() {}

pkg/state/protobuf/protobuf_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,44 @@ func TestProtobufWatchRestart(t *testing.T) {
215215
}
216216
}
217217

218+
func TestProtobufWatchInvalidBookmark(t *testing.T) {
219+
grpcConn, _, _, _ := ProtobufSetup(t) //nolint:dogsled
220+
221+
stateClient := v1alpha1.NewStateClient(grpcConn)
222+
223+
st := state.WrapCore(client.NewAdapter(stateClient,
224+
client.WithRetryLogger(zaptest.NewLogger(t)),
225+
))
226+
227+
ch := make(chan []state.Event)
228+
229+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
230+
t.Cleanup(cancel)
231+
232+
require.NoError(t, st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithBootstrapContents(true)))
233+
234+
var bookmark []byte
235+
236+
select {
237+
case <-ctx.Done():
238+
t.Fatal("timeout")
239+
case ev := <-ch:
240+
require.Len(t, ev, 1)
241+
242+
assert.Equal(t, state.Bootstrapped, ev[0].Type)
243+
assert.NotEmpty(t, ev[0].Bookmark)
244+
245+
bookmark = ev[0].Bookmark
246+
}
247+
248+
// send invalid bookmark
249+
bookmark[0] ^= 0xff
250+
251+
err := st.WatchKindAggregated(ctx, conformance.NewPathResource("test", "/foo").Metadata(), ch, state.WithKindStartFromBookmark(bookmark))
252+
require.Error(t, err)
253+
assert.True(t, state.IsInvalidWatchBookmarkError(err))
254+
}
255+
218256
func noError[T any](t *testing.T, fn func(T) error, v T, ignored ...error) {
219257
t.Helper()
220258

pkg/state/protobuf/server/server.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,12 @@ func (server *State) Watch(req *v1alpha1.WatchRequest, srv v1alpha1.State_WatchS
328328
}
329329

330330
if err != nil {
331-
return err
331+
switch {
332+
case state.IsInvalidWatchBookmarkError(err):
333+
return status.Error(codes.FailedPrecondition, err.Error())
334+
default:
335+
return err
336+
}
332337
}
333338

334339
// send empty event to signal that watch is ready

0 commit comments

Comments
 (0)